第 32 章

DNS 服务器:UDP 网络编程

第三十二章:DNS 服务器:UDP 网络编程

你每天使用的互联网,每一次打开浏览器输入 github.com,背后都发生了一件让人着迷的事情:你的操作系统向某台服务器发送了一个 53 字节左右的 UDP 数据包,几十毫秒后收到回应,然后才建立起 TCP 连接,开始传输网页内容。这个过程叫做 DNS 解析。

DNS(Domain Name System,域名系统)是互联网的电话簿。它将人类可读的域名翻译成机器可读的 IP 地址。没有 DNS,你需要记住 140.82.114.4 才能访问 GitHub。但 DNS 远不止是"查字典"那么简单——它是一个分布式的、层级化的、高度可扩展的系统,每天处理着全球数以万亿计的查询请求。

本章的目标是深入 DNS 的每一个细节,然后用 Go 从零构建一个 DNS 代理服务器。这不仅是学习 UDP 编程的最佳场景,也是理解互联网底层协议设计哲学的窗口。


Level 1 · 你需要知道的

DNS 的工作原理:一次完整的查询旅程

当你在浏览器输入 www.google.com 时,发生了以下事情:

第一步:检查本地缓存

操作系统首先检查自身的 DNS 缓存。如果最近访问过这个域名,且缓存的 TTL(Time To Live,生存时间)未过期,直接返回缓存结果,查询结束。

第二步:查询递归解析器

缓存未命中,操作系统向配置的 DNS 服务器(通常是你的路由器或 ISP 提供的服务器,或你手动配置的 8.8.8.8)发送查询请求。这台服务器叫做递归解析器(Recursive Resolver)。它代替你完成整个查询过程。

第三步:递归解析器询问根服务器

递归解析器也可能没有缓存。它向 DNS 根服务器发送查询。全球有 13 个根服务器 IP(A 到 M),但每个 IP 背后都有数百台实体机器提供 Anycast 服务。根服务器不知道 www.google.com 的 IP,但它知道谁负责 .com 域——它返回 .com 顶级域(TLD)的权威服务器地址。

第四步:询问 TLD 服务器

递归解析器向 .com TLD 服务器发送查询。TLD 服务器不知道 www.google.com 的具体 IP,但它知道 google.com 的权威服务器是哪些,返回这些服务器的地址。

第五步:询问权威服务器

递归解析器向 google.com 的权威服务器发送查询。权威服务器(Authoritative Server)是最终答案的来源——它持有该域名的 DNS 记录,直接返回 www.google.com 对应的 IP 地址。

第六步:结果逐级返回并缓存

答案沿路径返回,每一层都按照 TTL 值缓存结果,最终到达你的浏览器。整个过程通常在 100ms 以内完成。

递归解析器 vs 权威服务器

这是两个经常被混淆的概念:

类型 作用 例子
递归解析器 代替客户端完成递归查询,缓存结果 8.8.8.8(Google)、1.1.1.1(Cloudflare)、你的路由器
权威服务器 持有域名最终记录,直接回答 Cloudflare DNS、AWS Route 53、各大域名注册商

递归解析器是"勤劳的中间人",它做实际的查询工作,但不持有原始记录。权威服务器是"最终权威",它只回答自己管辖域的查询,不做递归。

为什么 DNS 用 UDP?

DNS 主要使用 UDP 协议,端口 53。为什么不用 TCP?

UDP 的优势:

TCP 的使用场景:

这就是为什么每一个 Go 开发者都应该理解 DNS:它是你写的每一个网络程序的基础设施。当你的服务在生产环境出现 DNS 解析超时,当你需要实现服务发现,当你要做广告拦截器或安全代理,都需要深入理解 DNS 的工作原理。


Level 2 · 原理深入

DNS 报文格式

DNS 报文格式由 RFC 1035 定义,是一个精心设计的紧凑二进制格式。理解它是实现 DNS 相关功能的基础。

一个 DNS 报文由五个部分组成:

+---------------------+
|        Header       |  12 字节固定长度
+---------------------+
|       Question      |  查询的域名和类型
+---------------------+
|        Answer       |  回答记录
+---------------------+
|      Authority      |  权威服务器记录
+---------------------+
|      Additional     |  附加记录
+---------------------+

Header 结构(12 字节):

 0  1  2  3  4  5  6  7  8  9  10 11 12 13 14 15
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|                      ID                       |   2字节:事务ID,随机生成,用于匹配请求和响应
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |   2字节:标志位
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|                    QDCOUNT                    |   2字节:问题数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|                    ANCOUNT                    |   2字节:回答记录数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|                    NSCOUNT                    |   2字节:权威记录数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|                    ARCOUNT                    |   2字节:附加记录数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+

关键标志位含义:

域名压缩编码

DNS 中域名的编码方式非常有趣。www.google.com 被编码为:

3 w w w 6 g o o g l e 3 c o m 0

每个标签(label)前面跟着一个长度字节,最后以 0x00 结尾。这是"长度前缀"编码。

但 DNS 还有一个节省空间的技巧——指针压缩。当报文中同一个域名出现多次时,第二次可以用一个两字节指针代替,指向第一次出现的位置。指针的特征是高两位为 11,即 0xC0 开头:

0xC0 0x0C  →  指向报文偏移 12 处的域名

实现 DNS 解析器时,必须正确处理这个递归指针,否则解析会出错甚至造成无限循环(恶意报文可构造循环指针)。

net.UDPConn:低级 UDP 编程

Go 的 net 包提供了两层 UDP 编程接口:

高层接口net.Dial("udp", ...) 返回 net.Conn,适合客户端场景,隐藏了远端地址的细节。

低层接口net.ListenUDP(...) 返回 *net.UDPConn,适合服务器场景,每次接收数据时都能获得发送方的地址。

DNS 服务器必须用低层接口,因为它需要记住每个查询来自哪个客户端地址,然后将响应发回该地址:

// DNS 服务器的核心:ListenUDP
conn, err := net.ListenUDP("udp", &net.UDPAddr{
    IP:   net.ParseIP("0.0.0.0"),
    Port: 53,
})
// ReadFromUDP 同时返回数据和发送方地址
n, clientAddr, err := conn.ReadFromUDP(buf)
// WriteToUDP 向指定地址发送响应
_, err = conn.WriteToUDP(response, clientAddr)

重要的性能陷阱:UDP 是无连接的,每次 ReadFromUDP 都是一次独立的系统调用,获取一个独立的数据报。与 TCP 流不同,UDP 保证消息边界——你发送 100 字节,对方就收到 100 字节的完整包,不会粘包也不会分片(在 MTU 限制范围内)。

Goroutine-per-Request vs Worker Pool

对于 DNS 服务器,处理每个查询的并发模型选择至关重要:

Goroutine-per-Request 模型:

ReadFromUDP → 为每个查询启动 go func() → 处理 → WriteToUDP

优点:简单,代码直观。缺点:在 DDoS 攻击场景下,可能创建数百万个 goroutine,内存耗尽。每个 goroutine 初始栈 2-8KB,100 万 goroutine 就是 2-8GB 内存。

Worker Pool 模型:

ReadFromUDP → channel → 固定数量 workers → 处理 → WriteToUDP

优点:内存可控,可以做背压(back pressure)。缺点:worker 数量选择影响吞吐量。

对于生产级 DNS 服务器,Worker Pool 是更好的选择,但 worker 数量应该根据 CPU 核数和上游 DNS 延迟来调整。如果上游 DNS 平均延迟 20ms,CPU 8 核,那么合理的 worker 数量大约是 8 * (1000/20) = 400,以保持 CPU 充分利用。

DNS 缓存与 TTL 管理

缓存是 DNS 性能的关键。TTL 是每条 DNS 记录携带的"保鲜期",单位是秒。实现缓存时有几个重要细节:

TTL 的递减问题:当你缓存一条 TTL 为 300 秒的记录时,缓存 150 秒后向客户端返回时,应该返回剩余的 150 秒 TTL,而不是原始的 300 秒。否则客户端可能在记录已经过期后还继续使用。

负缓存(Negative Caching):NXDOMAIN 响应也应该缓存,使用 SOA 记录中的 minimum TTL 值。否则对不存在域名的大量查询会反复打到上游服务器。

并发安全:DNS 缓存会被多个 worker goroutine 并发读写,必须用 sync.RWMutexsync.Map 保护。


Level 3 · 代码实战

构建 DNS 代理/转发器

下面我们构建一个完整的 DNS 代理服务器,支持缓存、上游转发和本地拦截列表。

项目结构:

dnsproxy/
├── main.go
├── server.go      # UDP 服务器和主循环
├── resolver.go    # DNS 查询解析
├── cache.go       # TTL 感知缓存
├── upstream.go    # 上游 DNS 转发
├── blocklist.go   # 广告拦截
└── doh.go         # DNS over HTTPS 上游

DNS 报文解析(resolver.go):

package main

import (
    "encoding/binary"
    "errors"
    "fmt"
    "strings"
)

// DNSHeader 表示 DNS 报文头部
type DNSHeader struct {
    ID      uint16
    Flags   uint16
    QDCount uint16
    ANCount uint16
    NSCount uint16
    ARCount uint16
}

// DNSQuestion 表示一个查询问题
type DNSQuestion struct {
    Name  string
    Type  uint16
    Class uint16
}

// DNSResourceRecord 表示一条资源记录
type DNSResourceRecord struct {
    Name     string
    Type     uint16
    Class    uint16
    TTL      uint32
    RDLength uint16
    RData    []byte
}

// DNSMessage 表示完整的 DNS 报文
type DNSMessage struct {
    Header      DNSHeader
    Questions   []DNSQuestion
    Answers     []DNSResourceRecord
    Authorities []DNSResourceRecord
    Additionals []DNSResourceRecord
    Raw         []byte // 原始字节,用于转发时直接传递
}

// 从字节流解析 DNS 报文头部
func parseHeader(data []byte) (DNSHeader, error) {
    if len(data) < 12 {
        return DNSHeader{}, errors.New("DNS message too short")
    }
    return DNSHeader{
        ID:      binary.BigEndian.Uint16(data[0:2]),
        Flags:   binary.BigEndian.Uint16(data[2:4]),
        QDCount: binary.BigEndian.Uint16(data[4:6]),
        ANCount: binary.BigEndian.Uint16(data[6:8]),
        NSCount: binary.BigEndian.Uint16(data[8:10]),
        ARCount: binary.BigEndian.Uint16(data[10:12]),
    }, nil
}

// 解析 DNS 域名,处理指针压缩
// offset 是当前解析位置,返回 (域名, 下一个字段的偏移量)
func parseDomainName(data []byte, offset int) (string, int, error) {
    var labels []string
    originalOffset := -1 // 遇到指针后保存原始位置

    for {
        if offset >= len(data) {
            return "", 0, errors.New("offset out of bounds")
        }

        length := int(data[offset])

        if length == 0 {
            // 根域名终止符
            offset++
            break
        }

        if length&0xC0 == 0xC0 {
            // 指针压缩:高两位为 11
            if offset+1 >= len(data) {
                return "", 0, errors.New("invalid pointer")
            }
            // 计算指针目标:低 14 位为偏移量
            ptr := int(binary.BigEndian.Uint16(data[offset:offset+2]) & 0x3FFF)

            if originalOffset == -1 {
                // 第一次遇到指针,记录当前位置以便后续继续解析
                originalOffset = offset + 2
            }

            // 防止恶意构造的循环指针
            if ptr >= offset {
                return "", 0, fmt.Errorf("forward pointer not allowed: %d >= %d", ptr, offset)
            }

            offset = ptr
            continue
        }

        if length&0xC0 != 0 {
            return "", 0, fmt.Errorf("invalid label length: %d", length)
        }

        // 普通标签
        offset++
        if offset+length > len(data) {
            return "", 0, errors.New("label exceeds data length")
        }
        labels = append(labels, string(data[offset:offset+length]))
        offset += length
    }

    if originalOffset != -1 {
        offset = originalOffset
    }

    return strings.Join(labels, "."), offset, nil
}

// 解析查询问题节
func parseQuestion(data []byte, offset int) (DNSQuestion, int, error) {
    name, newOffset, err := parseDomainName(data, offset)
    if err != nil {
        return DNSQuestion{}, 0, err
    }

    if newOffset+4 > len(data) {
        return DNSQuestion{}, 0, errors.New("question section too short")
    }

    return DNSQuestion{
        Name:  name,
        Type:  binary.BigEndian.Uint16(data[newOffset : newOffset+2]),
        Class: binary.BigEndian.Uint16(data[newOffset+2 : newOffset+4]),
    }, newOffset + 4, nil
}

// 解析完整 DNS 报文
func parseMessage(data []byte) (*DNSMessage, error) {
    header, err := parseHeader(data)
    if err != nil {
        return nil, err
    }

    msg := &DNSMessage{
        Header: header,
        Raw:    data,
    }

    offset := 12
    for i := 0; i < int(header.QDCount); i++ {
        q, newOffset, err := parseQuestion(data, offset)
        if err != nil {
            return nil, fmt.Errorf("parsing question %d: %w", i, err)
        }
        msg.Questions = append(msg.Questions, q)
        offset = newOffset
    }

    // Answer/Authority/Additional 记录解析类似,省略重复代码
    return msg, nil
}

// 构建域名查询的缓存键
func cacheKey(name string, qtype uint16) string {
    return fmt.Sprintf("%s:%d", strings.ToLower(name), qtype)
}

// QueryIsRecursionDesired 检查 RD 标志
func (h DNSHeader) IsRecursionDesired() bool {
    return h.Flags&0x0100 != 0
}

TTL 感知缓存(cache.go):

package main

import (
    "sync"
    "time"
)

// CacheEntry 存储缓存的 DNS 响应和过期时间
type CacheEntry struct {
    Response  []byte    // 原始 DNS 响应字节
    ExpiresAt time.Time // 基于原始 TTL 计算的过期时间
    OriginalTTL uint32  // 响应中最小的 TTL 值
}

// DNSCache 是线程安全的 DNS 响应缓存
type DNSCache struct {
    mu      sync.RWMutex
    entries map[string]*CacheEntry
}

func NewDNSCache() *DNSCache {
    c := &DNSCache{
        entries: make(map[string]*CacheEntry),
    }
    // 启动后台清理 goroutine
    go c.cleanup()
    return c
}

// Get 返回缓存响应,并调整报文中的 TTL 值为剩余 TTL
func (c *DNSCache) Get(key string) ([]byte, bool) {
    c.mu.RLock()
    entry, ok := c.entries[key]
    c.mu.RUnlock()

    if !ok {
        return nil, false
    }

    remaining := time.Until(entry.ExpiresAt)
    if remaining <= 0 {
        c.mu.Lock()
        delete(c.entries, key)
        c.mu.Unlock()
        return nil, false
    }

    // 深拷贝响应,然后更新所有 TTL 字段为剩余时间
    response := make([]byte, len(entry.Response))
    copy(response, entry.Response)

    remainingSeconds := uint32(remaining.Seconds())
    updateTTLInResponse(response, remainingSeconds)

    return response, true
}

// updateTTLInResponse 修改响应报文中的所有资源记录 TTL 为剩余值
// 注意:这需要跳过 Header 和 Question 部分,定位到 Answer 记录
func updateTTLInResponse(response []byte, ttl uint32) {
    if len(response) < 12 {
        return
    }

    anCount := int(binary.BigEndian.Uint16(response[6:8]))
    nsCount := int(binary.BigEndian.Uint16(response[8:10]))
    arCount := int(binary.BigEndian.Uint16(response[10:12]))
    totalRR := anCount + nsCount + arCount

    // 跳过 Header(12字节)和 Question 部分
    offset := 12
    qdCount := int(binary.BigEndian.Uint16(response[4:6]))
    for i := 0; i < qdCount; i++ {
        _, newOffset, err := parseDomainName(response, offset)
        if err != nil {
            return
        }
        offset = newOffset + 4 // 跳过 QTYPE 和 QCLASS
    }

    // 更新每条资源记录的 TTL
    for i := 0; i < totalRR; i++ {
        _, newOffset, err := parseDomainName(response, offset)
        if err != nil {
            return
        }
        offset = newOffset + 4 // 跳过 TYPE 和 CLASS

        if offset+6 > len(response) {
            return
        }

        // TTL 在 TYPE+CLASS 之后的 4 字节
        binary.BigEndian.PutUint32(response[offset:offset+4], ttl)
        rdLength := int(binary.BigEndian.Uint16(response[offset+4 : offset+6]))
        offset += 6 + rdLength
    }
}

// Set 缓存一个 DNS 响应
func (c *DNSCache) Set(key string, response []byte, ttl uint32) {
    if ttl == 0 {
        return // TTL 为 0 的记录不缓存
    }

    c.mu.Lock()
    c.entries[key] = &CacheEntry{
        Response:    response,
        ExpiresAt:   time.Now().Add(time.Duration(ttl) * time.Second),
        OriginalTTL: ttl,
    }
    c.mu.Unlock()
}

// cleanup 定期清理过期缓存条目
func (c *DNSCache) cleanup() {
    ticker := time.NewTicker(30 * time.Second)
    for range ticker.C {
        now := time.Now()
        c.mu.Lock()
        for k, v := range c.entries {
            if now.After(v.ExpiresAt) {
                delete(c.entries, k)
            }
        }
        c.mu.Unlock()
    }
}

上游 UDP 转发(upstream.go):

package main

import (
    "fmt"
    "net"
    "time"
)

// UpstreamResolver 向上游 DNS 服务器转发查询
type UpstreamResolver struct {
    servers []string // 上游 DNS 服务器地址列表,如 ["8.8.8.8:53", "1.1.1.1:53"]
    timeout time.Duration
}

func NewUpstreamResolver(servers []string) *UpstreamResolver {
    return &UpstreamResolver{
        servers: servers,
        timeout: 5 * time.Second,
    }
}

// Query 向上游转发查询,返回原始响应字节
func (r *UpstreamResolver) Query(query []byte) ([]byte, error) {
    var lastErr error

    for _, server := range r.servers {
        resp, err := r.queryServer(server, query)
        if err != nil {
            lastErr = err
            continue
        }
        return resp, nil
    }

    return nil, fmt.Errorf("all upstream servers failed, last error: %w", lastErr)
}

func (r *UpstreamResolver) queryServer(server string, query []byte) ([]byte, error) {
    conn, err := net.DialTimeout("udp", server, r.timeout)
    if err != nil {
        return nil, fmt.Errorf("connecting to %s: %w", server, err)
    }
    defer conn.Close()

    conn.SetDeadline(time.Now().Add(r.timeout))

    _, err = conn.Write(query)
    if err != nil {
        return nil, fmt.Errorf("sending query to %s: %w", server, err)
    }

    buf := make([]byte, 4096) // 支持 EDNS0 的大响应
    n, err := conn.Read(buf)
    if err != nil {
        return nil, fmt.Errorf("reading response from %s: %w", server, err)
    }

    return buf[:n], nil
}

DNS over HTTPS 上游(doh.go):

package main

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
    "time"
)

// DoHResolver 通过 HTTPS 向上游发送 DNS 查询(RFC 8484)
type DoHResolver struct {
    url    string
    client *http.Client
}

func NewDoHResolver(url string) *DoHResolver {
    return &DoHResolver{
        url: url,
        client: &http.Client{
            Timeout: 5 * time.Second,
        },
    }
}

// Query 使用 DNS Wireformat over HTTPS(application/dns-message)发送查询
func (r *DoHResolver) Query(query []byte) ([]byte, error) {
    req, err := http.NewRequest("POST", r.url, bytes.NewReader(query))
    if err != nil {
        return nil, err
    }

    // RFC 8484 规定的 Content-Type
    req.Header.Set("Content-Type", "application/dns-message")
    req.Header.Set("Accept", "application/dns-message")

    resp, err := r.client.Do(req)
    if err != nil {
        return nil, fmt.Errorf("DoH request failed: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("DoH server returned status %d", resp.StatusCode)
    }

    return io.ReadAll(io.LimitReader(resp.Body, 65535))
}

广告拦截列表(blocklist.go):

package main

import (
    "bufio"
    "os"
    "strings"
    "sync"
)

// Blocklist 持有被拦截的域名集合
type Blocklist struct {
    mu      sync.RWMutex
    domains map[string]struct{}
}

func NewBlocklist() *Blocklist {
    return &Blocklist{
        domains: make(map[string]struct{}),
    }
}

// LoadFromFile 从 hosts 格式文件加载拦截列表
// 支持格式:0.0.0.0 ads.example.com 或 # 注释
func (b *Blocklist) LoadFromFile(path string) error {
    f, err := os.Open(path)
    if err != nil {
        return err
    }
    defer f.Close()

    b.mu.Lock()
    defer b.mu.Unlock()

    scanner := bufio.NewScanner(f)
    for scanner.Scan() {
        line := strings.TrimSpace(scanner.Text())
        if line == "" || strings.HasPrefix(line, "#") {
            continue
        }
        fields := strings.Fields(line)
        if len(fields) >= 2 {
            // hosts 格式:IP 域名
            domain := strings.ToLower(fields[1])
            b.domains[domain] = struct{}{}
        }
    }

    return scanner.Err()
}

// IsBlocked 检查域名是否被拦截,同时检查父域名
func (b *Blocklist) IsBlocked(name string) bool {
    name = strings.ToLower(strings.TrimSuffix(name, "."))

    b.mu.RLock()
    defer b.mu.RUnlock()

    // 精确匹配
    if _, ok := b.domains[name]; ok {
        return true
    }

    // 检查父域名(支持通配符拦截)
    for {
        idx := strings.Index(name, ".")
        if idx == -1 {
            break
        }
        name = name[idx+1:]
        if _, ok := b.domains[name]; ok {
            return true
        }
    }

    return false
}

// buildNXDOMAINResponse 构建 NXDOMAIN 响应,用于被拦截的域名
func buildNXDOMAINResponse(query []byte) []byte {
    if len(query) < 12 {
        return nil
    }

    response := make([]byte, len(query))
    copy(response, query)

    // 将 QR 置 1(响应)、设置 RCODE 为 3(NXDOMAIN)
    // Flags 在字节 2-3
    flags := binary.BigEndian.Uint16(query[2:4])
    flags |= 0x8000 // QR = 1
    flags |= 0x0003 // RCODE = 3 (NXDOMAIN)
    flags &^= 0x0200 // AA = 0
    binary.BigEndian.PutUint16(response[2:4], flags)

    return response
}

主服务器逻辑(server.go):

package main

import (
    "encoding/binary"
    "log"
    "net"
    "sync"
)

// DNSProxy 是完整的 DNS 代理服务器
type DNSProxy struct {
    conn      *net.UDPConn
    cache     *DNSCache
    upstream  *UpstreamResolver
    doh       *DoHResolver
    blocklist *Blocklist
    workers   int
    jobs      chan dnsJob
}

type dnsJob struct {
    data       []byte
    clientAddr *net.UDPAddr
}

func NewDNSProxy(listenAddr string, workers int) (*DNSProxy, error) {
    addr, err := net.ResolveUDPAddr("udp", listenAddr)
    if err != nil {
        return nil, err
    }

    conn, err := net.ListenUDP("udp", addr)
    if err != nil {
        return nil, err
    }

    // 增大 socket 接收缓冲区,减少高并发时的丢包
    conn.SetReadBuffer(4 * 1024 * 1024)
    conn.SetWriteBuffer(4 * 1024 * 1024)

    return &DNSProxy{
        conn:      conn,
        cache:     NewDNSCache(),
        upstream:  NewUpstreamResolver([]string{"8.8.8.8:53", "1.1.1.1:53"}),
        doh:       NewDoHResolver("https://cloudflare-dns.com/dns-query"),
        blocklist: NewBlocklist(),
        workers:   workers,
        jobs:      make(chan dnsJob, workers*10),
    }, nil
}

func (p *DNSProxy) Start() {
    // 启动 worker pool
    var wg sync.WaitGroup
    for i := 0; i < p.workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            p.worker()
        }()
    }

    // 主循环:接收 UDP 数据报
    buf := make([]byte, 4096)
    for {
        n, clientAddr, err := p.conn.ReadFromUDP(buf)
        if err != nil {
            log.Printf("ReadFromUDP error: %v", err)
            continue
        }

        // 复制数据,避免下次 ReadFromUDP 覆盖
        data := make([]byte, n)
        copy(data, buf[:n])

        // 非阻塞发送到 jobs channel,防止慢客户端拖慢接收循环
        select {
        case p.jobs <- dnsJob{data: data, clientAddr: clientAddr}:
        default:
            log.Printf("Worker pool full, dropping query from %s", clientAddr)
        }
    }
}

func (p *DNSProxy) worker() {
    for job := range p.jobs {
        response := p.handleQuery(job.data)
        if response != nil {
            if _, err := p.conn.WriteToUDP(response, job.clientAddr); err != nil {
                log.Printf("WriteToUDP error: %v", err)
            }
        }
    }
}

func (p *DNSProxy) handleQuery(data []byte) []byte {
    msg, err := parseMessage(data)
    if err != nil || len(msg.Questions) == 0 {
        return nil
    }

    q := msg.Questions[0]

    // 检查拦截列表
    if p.blocklist.IsBlocked(q.Name) {
        log.Printf("Blocked: %s", q.Name)
        return buildNXDOMAINResponse(data)
    }

    key := cacheKey(q.Name, q.Type)

    // 检查缓存
    if cached, ok := p.cache.Get(key); ok {
        // 修正响应中的事务 ID 以匹配本次查询
        response := make([]byte, len(cached))
        copy(response, cached)
        binary.BigEndian.PutUint16(response[0:2], msg.Header.ID)
        return response
    }

    // 向上游转发
    upstream, err := p.upstream.Query(data)
    if err != nil {
        log.Printf("Upstream error for %s: %v", q.Name, err)
        // 降级到 DoH
        upstream, err = p.doh.Query(data)
        if err != nil {
            log.Printf("DoH error for %s: %v", q.Name, err)
            return nil
        }
    }

    // 提取最小 TTL,存入缓存
    minTTL := extractMinTTL(upstream)
    if minTTL > 0 {
        p.cache.Set(key, upstream, minTTL)
    }

    return upstream
}

// extractMinTTL 从 DNS 响应中提取所有资源记录的最小 TTL
func extractMinTTL(response []byte) uint32 {
    if len(response) < 12 {
        return 0
    }

    anCount := int(binary.BigEndian.Uint16(response[6:8]))
    if anCount == 0 {
        return 0
    }

    var minTTL uint32 = ^uint32(0) // 初始化为最大值
    offset := 12

    // 跳过 Question 部分
    qdCount := int(binary.BigEndian.Uint16(response[4:6]))
    for i := 0; i < qdCount; i++ {
        _, newOffset, err := parseDomainName(response, offset)
        if err != nil {
            return 0
        }
        offset = newOffset + 4
    }

    // 遍历 Answer 记录
    for i := 0; i < anCount; i++ {
        _, newOffset, err := parseDomainName(response, offset)
        if err != nil {
            return 0
        }
        offset = newOffset + 4 // 跳过 TYPE 和 CLASS

        if offset+6 > len(response) {
            return 0
        }

        ttl := binary.BigEndian.Uint32(response[offset : offset+4])
        if ttl < minTTL {
            minTTL = ttl
        }

        rdLength := int(binary.BigEndian.Uint16(response[offset+4 : offset+6]))
        offset += 6 + rdLength
    }

    if minTTL == ^uint32(0) {
        return 0
    }
    return minTTL
}

func main() {
    proxy, err := NewDNSProxy(":5353", 100)
    if err != nil {
        log.Fatalf("Failed to create DNS proxy: %v", err)
    }

    if err := proxy.blocklist.LoadFromFile("blocklist.txt"); err != nil {
        log.Printf("Warning: could not load blocklist: %v", err)
    }

    log.Println("DNS proxy listening on :5353")
    proxy.Start()
}

Level 4 · 进阶与边界

DNS 负载均衡:Round-Robin A 记录

DNS 轮询是最简单的负载均衡方式。权威服务器为同一个域名返回多个 A 记录,每次查询时按不同顺序排列,客户端通常使用第一个 IP:

// 在响应构建时随机打乱 A 记录顺序,实现简单的轮询
func shuffleARecords(response []byte) []byte {
    // 解析所有 Answer 中的 A 记录,随机排列后重新打包
    // 这是一个简化示意,完整实现需要重新序列化报文
    // 生产环境建议使用 miekg/dns 库
    return response
}

更成熟的 DNS 负载均衡会结合地理位置路由(GeoDNS)——根据客户端 IP 的地理位置返回最近的服务器 IP,这需要 IP 地理位置数据库(如 MaxMind GeoIP2)。

miekg/dns 库深入

手工解析 DNS 报文虽然能让你深入理解协议细节,但生产环境应该使用 github.com/miekg/dns 这个成熟的库,它处理了所有边界情况:

import "github.com/miekg/dns"

// 使用 miekg/dns 实现相同的代理功能
func handleWithMiekg(w dns.ResponseWriter, r *dns.Msg) {
    // 检查拦截列表
    if len(r.Question) > 0 && isBlocked(r.Question[0].Name) {
        m := new(dns.Msg)
        m.SetRcode(r, dns.RcodeNameError) // NXDOMAIN
        w.WriteMsg(m)
        return
    }

    // 向上游转发
    c := new(dns.Client)
    resp, _, err := c.Exchange(r, "8.8.8.8:53")
    if err != nil {
        dns.HandleFailed(w, r)
        return
    }

    w.WriteMsg(resp)
}

func main() {
    dns.HandleFunc(".", handleWithMiekg)
    server := &dns.Server{Addr: ":5353", Net: "udp"}
    log.Fatal(server.ListenAndServe())
}

miekg/dnsdns.Server 内部使用了 goroutine-per-request 模型,但限制了最大并发数,并正确处理了 EDNS0、DNSSEC 等复杂特性。

DNSSEC 基础

DNSSEC(DNS Security Extensions)通过数字签名保证 DNS 响应的真实性和完整性,防止 DNS 缓存投毒攻击。它引入了几种新的记录类型:

验证 DNSSEC 的链式信任:从根区域的信任锚(IANA 公布的根密钥)开始,逐级验证每一层的 DS 记录和 DNSKEY 记录,最终验证回答记录的 RRSIG 签名。

DNS 放大攻击防御

DNS 反射放大攻击(Amplification Attack)是一种 DDoS 攻击:攻击者伪造受害者的 IP 发送 DNS 查询,响应发往受害者。DNS 的放大倍数可以很高(某些查询响应/请求比例超过 100:1)。

作为 DNS 服务器开发者,防御措施:

// 1. 限制响应中 ANY 查询(放大攻击常用)
if q.Type == dns.TypeANY {
    // 返回 HINFO 记录代替全量响应(RFC 8482)
    m := new(dns.Msg)
    m.SetReply(r)
    m.Answer = append(m.Answer, &dns.HINFO{
        Hdr: dns.RR_Header{
            Name: q.Name, Rrtype: dns.TypeHINFO,
            Class: dns.ClassINET, Ttl: 0,
        },
        Cpu: "ANY obsoleted",
        Os:  "See RFC 8482",
    })
    w.WriteMsg(m)
    return
}

// 2. Response Rate Limiting (RRL):对同一来源 IP 的响应速率限制
// 3. 开放解析器限制:只响应内部 IP 的递归查询
// 4. 设置 TC 位,强制小响应,让攻击者切换 TCP(无法伪造 IP)

实现简单权威 DNS 服务器

// 权威 DNS 服务器:持有区域记录,直接回答查询
type AuthoritativeServer struct {
    zone map[string][]dns.RR // 键是 "name:type",值是资源记录列表
    mu   sync.RWMutex
}

func (s *AuthoritativeServer) Handle(w dns.ResponseWriter, r *dns.Msg) {
    if len(r.Question) == 0 {
        return
    }

    q := r.Question[0]
    key := fmt.Sprintf("%s:%d", strings.ToLower(q.Name), q.Qtype)

    s.mu.RLock()
    records, ok := s.zone[key]
    s.mu.RUnlock()

    m := new(dns.Msg)
    m.SetReply(r)
    m.Authoritative = true // AA 位置 1,表示权威回答

    if !ok {
        m.SetRcode(r, dns.RcodeNameError) // NXDOMAIN
    } else {
        m.Answer = records
    }

    w.WriteMsg(m)
}

DNS 是互联网最关键的基础设施之一,也是 UDP 编程最经典的应用场景。从手写解析器到 miekg/dns 的高阶封装,从简单代理到权威服务器,每一层都揭示了互联网底层运转的精密机制。理解这些,你才能在网络编程中真正做到"知其然,更知其所以然"。

本章评分
4.8  / 5  (3 评分)

💬 留言讨论