DNS 服务器:UDP 网络编程
第三十二章:DNS 服务器:UDP 网络编程
你每天使用的互联网,每一次打开浏览器输入 github.com,背后都发生了一件让人着迷的事情:你的操作系统向某台服务器发送了一个 53 字节左右的 UDP 数据包,几十毫秒后收到回应,然后才建立起 TCP 连接,开始传输网页内容。这个过程叫做 DNS 解析。
DNS(Domain Name System,域名系统)是互联网的电话簿。它将人类可读的域名翻译成机器可读的 IP 地址。没有 DNS,你需要记住 140.82.114.4 才能访问 GitHub。但 DNS 远不止是"查字典"那么简单——它是一个分布式的、层级化的、高度可扩展的系统,每天处理着全球数以万亿计的查询请求。
本章的目标是深入 DNS 的每一个细节,然后用 Go 从零构建一个 DNS 代理服务器。这不仅是学习 UDP 编程的最佳场景,也是理解互联网底层协议设计哲学的窗口。
Level 1 · 你需要知道的
DNS 的工作原理:一次完整的查询旅程
当你在浏览器输入 www.google.com 时,发生了以下事情:
第一步:检查本地缓存
操作系统首先检查自身的 DNS 缓存。如果最近访问过这个域名,且缓存的 TTL(Time To Live,生存时间)未过期,直接返回缓存结果,查询结束。
第二步:查询递归解析器
缓存未命中,操作系统向配置的 DNS 服务器(通常是你的路由器或 ISP 提供的服务器,或你手动配置的 8.8.8.8)发送查询请求。这台服务器叫做递归解析器(Recursive Resolver)。它代替你完成整个查询过程。
第三步:递归解析器询问根服务器
递归解析器也可能没有缓存。它向 DNS 根服务器发送查询。全球有 13 个根服务器 IP(A 到 M),但每个 IP 背后都有数百台实体机器提供 Anycast 服务。根服务器不知道 www.google.com 的 IP,但它知道谁负责 .com 域——它返回 .com 顶级域(TLD)的权威服务器地址。
第四步:询问 TLD 服务器
递归解析器向 .com TLD 服务器发送查询。TLD 服务器不知道 www.google.com 的具体 IP,但它知道 google.com 的权威服务器是哪些,返回这些服务器的地址。
第五步:询问权威服务器
递归解析器向 google.com 的权威服务器发送查询。权威服务器(Authoritative Server)是最终答案的来源——它持有该域名的 DNS 记录,直接返回 www.google.com 对应的 IP 地址。
第六步:结果逐级返回并缓存
答案沿路径返回,每一层都按照 TTL 值缓存结果,最终到达你的浏览器。整个过程通常在 100ms 以内完成。
递归解析器 vs 权威服务器
这是两个经常被混淆的概念:
| 类型 | 作用 | 例子 |
|---|---|---|
| 递归解析器 | 代替客户端完成递归查询,缓存结果 | 8.8.8.8(Google)、1.1.1.1(Cloudflare)、你的路由器 |
| 权威服务器 | 持有域名最终记录,直接回答 | Cloudflare DNS、AWS Route 53、各大域名注册商 |
递归解析器是"勤劳的中间人",它做实际的查询工作,但不持有原始记录。权威服务器是"最终权威",它只回答自己管辖域的查询,不做递归。
为什么 DNS 用 UDP?
DNS 主要使用 UDP 协议,端口 53。为什么不用 TCP?
UDP 的优势:
- 低延迟:无需三次握手,请求发出即等待回应
- 轻量:DNS 查询通常很小(几十字节),一个 UDP 包就够了
- 无连接:服务器无需维护连接状态,可以极高效率处理海量并发查询
TCP 的使用场景:
- DNS 响应超过 512 字节(EDNS0 扩展后是 4096 字节)时,截断响应后客户端会切换到 TCP 重试
- DNS 区域传输(AXFR/IXFR)——主服务器向从服务器传输完整 DNS 记录,数据量大,必须用 TCP
- DNS over TLS(DoT)和 DNS over HTTPS(DoH)——加密协议,基于 TCP
这就是为什么每一个 Go 开发者都应该理解 DNS:它是你写的每一个网络程序的基础设施。当你的服务在生产环境出现 DNS 解析超时,当你需要实现服务发现,当你要做广告拦截器或安全代理,都需要深入理解 DNS 的工作原理。
Level 2 · 原理深入
DNS 报文格式
DNS 报文格式由 RFC 1035 定义,是一个精心设计的紧凑二进制格式。理解它是实现 DNS 相关功能的基础。
一个 DNS 报文由五个部分组成:
+---------------------+
| Header | 12 字节固定长度
+---------------------+
| Question | 查询的域名和类型
+---------------------+
| Answer | 回答记录
+---------------------+
| Authority | 权威服务器记录
+---------------------+
| Additional | 附加记录
+---------------------+
Header 结构(12 字节):
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| ID | 2字节:事务ID,随机生成,用于匹配请求和响应
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|QR| Opcode |AA|TC|RD|RA| Z | RCODE | 2字节:标志位
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| QDCOUNT | 2字节:问题数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| ANCOUNT | 2字节:回答记录数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| NSCOUNT | 2字节:权威记录数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| ARCOUNT | 2字节:附加记录数量
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
关键标志位含义:
- QR:0 = 查询,1 = 响应
- AA(Authoritative Answer):响应来自权威服务器时置 1
- TC(TrunCation):响应被截断时置 1,客户端应切换 TCP 重试
- RD(Recursion Desired):客户端请求递归查询
- RA(Recursion Available):服务器支持递归查询
- RCODE:响应码,0 = 成功,3 = NXDOMAIN(域名不存在)
域名压缩编码
DNS 中域名的编码方式非常有趣。www.google.com 被编码为:
3 w w w 6 g o o g l e 3 c o m 0
每个标签(label)前面跟着一个长度字节,最后以 0x00 结尾。这是"长度前缀"编码。
但 DNS 还有一个节省空间的技巧——指针压缩。当报文中同一个域名出现多次时,第二次可以用一个两字节指针代替,指向第一次出现的位置。指针的特征是高两位为 11,即 0xC0 开头:
0xC0 0x0C → 指向报文偏移 12 处的域名
实现 DNS 解析器时,必须正确处理这个递归指针,否则解析会出错甚至造成无限循环(恶意报文可构造循环指针)。
net.UDPConn:低级 UDP 编程
Go 的 net 包提供了两层 UDP 编程接口:
高层接口:net.Dial("udp", ...) 返回 net.Conn,适合客户端场景,隐藏了远端地址的细节。
低层接口:net.ListenUDP(...) 返回 *net.UDPConn,适合服务器场景,每次接收数据时都能获得发送方的地址。
DNS 服务器必须用低层接口,因为它需要记住每个查询来自哪个客户端地址,然后将响应发回该地址:
// DNS 服务器的核心:ListenUDP
conn, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("0.0.0.0"),
Port: 53,
})
// ReadFromUDP 同时返回数据和发送方地址
n, clientAddr, err := conn.ReadFromUDP(buf)
// WriteToUDP 向指定地址发送响应
_, err = conn.WriteToUDP(response, clientAddr)
重要的性能陷阱:UDP 是无连接的,每次 ReadFromUDP 都是一次独立的系统调用,获取一个独立的数据报。与 TCP 流不同,UDP 保证消息边界——你发送 100 字节,对方就收到 100 字节的完整包,不会粘包也不会分片(在 MTU 限制范围内)。
Goroutine-per-Request vs Worker Pool
对于 DNS 服务器,处理每个查询的并发模型选择至关重要:
Goroutine-per-Request 模型:
ReadFromUDP → 为每个查询启动 go func() → 处理 → WriteToUDP
优点:简单,代码直观。缺点:在 DDoS 攻击场景下,可能创建数百万个 goroutine,内存耗尽。每个 goroutine 初始栈 2-8KB,100 万 goroutine 就是 2-8GB 内存。
Worker Pool 模型:
ReadFromUDP → channel → 固定数量 workers → 处理 → WriteToUDP
优点:内存可控,可以做背压(back pressure)。缺点:worker 数量选择影响吞吐量。
对于生产级 DNS 服务器,Worker Pool 是更好的选择,但 worker 数量应该根据 CPU 核数和上游 DNS 延迟来调整。如果上游 DNS 平均延迟 20ms,CPU 8 核,那么合理的 worker 数量大约是 8 * (1000/20) = 400,以保持 CPU 充分利用。
DNS 缓存与 TTL 管理
缓存是 DNS 性能的关键。TTL 是每条 DNS 记录携带的"保鲜期",单位是秒。实现缓存时有几个重要细节:
TTL 的递减问题:当你缓存一条 TTL 为 300 秒的记录时,缓存 150 秒后向客户端返回时,应该返回剩余的 150 秒 TTL,而不是原始的 300 秒。否则客户端可能在记录已经过期后还继续使用。
负缓存(Negative Caching):NXDOMAIN 响应也应该缓存,使用 SOA 记录中的 minimum TTL 值。否则对不存在域名的大量查询会反复打到上游服务器。
并发安全:DNS 缓存会被多个 worker goroutine 并发读写,必须用 sync.RWMutex 或 sync.Map 保护。
Level 3 · 代码实战
构建 DNS 代理/转发器
下面我们构建一个完整的 DNS 代理服务器,支持缓存、上游转发和本地拦截列表。
项目结构:
dnsproxy/
├── main.go
├── server.go # UDP 服务器和主循环
├── resolver.go # DNS 查询解析
├── cache.go # TTL 感知缓存
├── upstream.go # 上游 DNS 转发
├── blocklist.go # 广告拦截
└── doh.go # DNS over HTTPS 上游
DNS 报文解析(resolver.go):
package main
import (
"encoding/binary"
"errors"
"fmt"
"strings"
)
// DNSHeader 表示 DNS 报文头部
type DNSHeader struct {
ID uint16
Flags uint16
QDCount uint16
ANCount uint16
NSCount uint16
ARCount uint16
}
// DNSQuestion 表示一个查询问题
type DNSQuestion struct {
Name string
Type uint16
Class uint16
}
// DNSResourceRecord 表示一条资源记录
type DNSResourceRecord struct {
Name string
Type uint16
Class uint16
TTL uint32
RDLength uint16
RData []byte
}
// DNSMessage 表示完整的 DNS 报文
type DNSMessage struct {
Header DNSHeader
Questions []DNSQuestion
Answers []DNSResourceRecord
Authorities []DNSResourceRecord
Additionals []DNSResourceRecord
Raw []byte // 原始字节,用于转发时直接传递
}
// 从字节流解析 DNS 报文头部
func parseHeader(data []byte) (DNSHeader, error) {
if len(data) < 12 {
return DNSHeader{}, errors.New("DNS message too short")
}
return DNSHeader{
ID: binary.BigEndian.Uint16(data[0:2]),
Flags: binary.BigEndian.Uint16(data[2:4]),
QDCount: binary.BigEndian.Uint16(data[4:6]),
ANCount: binary.BigEndian.Uint16(data[6:8]),
NSCount: binary.BigEndian.Uint16(data[8:10]),
ARCount: binary.BigEndian.Uint16(data[10:12]),
}, nil
}
// 解析 DNS 域名,处理指针压缩
// offset 是当前解析位置,返回 (域名, 下一个字段的偏移量)
func parseDomainName(data []byte, offset int) (string, int, error) {
var labels []string
originalOffset := -1 // 遇到指针后保存原始位置
for {
if offset >= len(data) {
return "", 0, errors.New("offset out of bounds")
}
length := int(data[offset])
if length == 0 {
// 根域名终止符
offset++
break
}
if length&0xC0 == 0xC0 {
// 指针压缩:高两位为 11
if offset+1 >= len(data) {
return "", 0, errors.New("invalid pointer")
}
// 计算指针目标:低 14 位为偏移量
ptr := int(binary.BigEndian.Uint16(data[offset:offset+2]) & 0x3FFF)
if originalOffset == -1 {
// 第一次遇到指针,记录当前位置以便后续继续解析
originalOffset = offset + 2
}
// 防止恶意构造的循环指针
if ptr >= offset {
return "", 0, fmt.Errorf("forward pointer not allowed: %d >= %d", ptr, offset)
}
offset = ptr
continue
}
if length&0xC0 != 0 {
return "", 0, fmt.Errorf("invalid label length: %d", length)
}
// 普通标签
offset++
if offset+length > len(data) {
return "", 0, errors.New("label exceeds data length")
}
labels = append(labels, string(data[offset:offset+length]))
offset += length
}
if originalOffset != -1 {
offset = originalOffset
}
return strings.Join(labels, "."), offset, nil
}
// 解析查询问题节
func parseQuestion(data []byte, offset int) (DNSQuestion, int, error) {
name, newOffset, err := parseDomainName(data, offset)
if err != nil {
return DNSQuestion{}, 0, err
}
if newOffset+4 > len(data) {
return DNSQuestion{}, 0, errors.New("question section too short")
}
return DNSQuestion{
Name: name,
Type: binary.BigEndian.Uint16(data[newOffset : newOffset+2]),
Class: binary.BigEndian.Uint16(data[newOffset+2 : newOffset+4]),
}, newOffset + 4, nil
}
// 解析完整 DNS 报文
func parseMessage(data []byte) (*DNSMessage, error) {
header, err := parseHeader(data)
if err != nil {
return nil, err
}
msg := &DNSMessage{
Header: header,
Raw: data,
}
offset := 12
for i := 0; i < int(header.QDCount); i++ {
q, newOffset, err := parseQuestion(data, offset)
if err != nil {
return nil, fmt.Errorf("parsing question %d: %w", i, err)
}
msg.Questions = append(msg.Questions, q)
offset = newOffset
}
// Answer/Authority/Additional 记录解析类似,省略重复代码
return msg, nil
}
// 构建域名查询的缓存键
func cacheKey(name string, qtype uint16) string {
return fmt.Sprintf("%s:%d", strings.ToLower(name), qtype)
}
// QueryIsRecursionDesired 检查 RD 标志
func (h DNSHeader) IsRecursionDesired() bool {
return h.Flags&0x0100 != 0
}
TTL 感知缓存(cache.go):
package main
import (
"sync"
"time"
)
// CacheEntry 存储缓存的 DNS 响应和过期时间
type CacheEntry struct {
Response []byte // 原始 DNS 响应字节
ExpiresAt time.Time // 基于原始 TTL 计算的过期时间
OriginalTTL uint32 // 响应中最小的 TTL 值
}
// DNSCache 是线程安全的 DNS 响应缓存
type DNSCache struct {
mu sync.RWMutex
entries map[string]*CacheEntry
}
func NewDNSCache() *DNSCache {
c := &DNSCache{
entries: make(map[string]*CacheEntry),
}
// 启动后台清理 goroutine
go c.cleanup()
return c
}
// Get 返回缓存响应,并调整报文中的 TTL 值为剩余 TTL
func (c *DNSCache) Get(key string) ([]byte, bool) {
c.mu.RLock()
entry, ok := c.entries[key]
c.mu.RUnlock()
if !ok {
return nil, false
}
remaining := time.Until(entry.ExpiresAt)
if remaining <= 0 {
c.mu.Lock()
delete(c.entries, key)
c.mu.Unlock()
return nil, false
}
// 深拷贝响应,然后更新所有 TTL 字段为剩余时间
response := make([]byte, len(entry.Response))
copy(response, entry.Response)
remainingSeconds := uint32(remaining.Seconds())
updateTTLInResponse(response, remainingSeconds)
return response, true
}
// updateTTLInResponse 修改响应报文中的所有资源记录 TTL 为剩余值
// 注意:这需要跳过 Header 和 Question 部分,定位到 Answer 记录
func updateTTLInResponse(response []byte, ttl uint32) {
if len(response) < 12 {
return
}
anCount := int(binary.BigEndian.Uint16(response[6:8]))
nsCount := int(binary.BigEndian.Uint16(response[8:10]))
arCount := int(binary.BigEndian.Uint16(response[10:12]))
totalRR := anCount + nsCount + arCount
// 跳过 Header(12字节)和 Question 部分
offset := 12
qdCount := int(binary.BigEndian.Uint16(response[4:6]))
for i := 0; i < qdCount; i++ {
_, newOffset, err := parseDomainName(response, offset)
if err != nil {
return
}
offset = newOffset + 4 // 跳过 QTYPE 和 QCLASS
}
// 更新每条资源记录的 TTL
for i := 0; i < totalRR; i++ {
_, newOffset, err := parseDomainName(response, offset)
if err != nil {
return
}
offset = newOffset + 4 // 跳过 TYPE 和 CLASS
if offset+6 > len(response) {
return
}
// TTL 在 TYPE+CLASS 之后的 4 字节
binary.BigEndian.PutUint32(response[offset:offset+4], ttl)
rdLength := int(binary.BigEndian.Uint16(response[offset+4 : offset+6]))
offset += 6 + rdLength
}
}
// Set 缓存一个 DNS 响应
func (c *DNSCache) Set(key string, response []byte, ttl uint32) {
if ttl == 0 {
return // TTL 为 0 的记录不缓存
}
c.mu.Lock()
c.entries[key] = &CacheEntry{
Response: response,
ExpiresAt: time.Now().Add(time.Duration(ttl) * time.Second),
OriginalTTL: ttl,
}
c.mu.Unlock()
}
// cleanup 定期清理过期缓存条目
func (c *DNSCache) cleanup() {
ticker := time.NewTicker(30 * time.Second)
for range ticker.C {
now := time.Now()
c.mu.Lock()
for k, v := range c.entries {
if now.After(v.ExpiresAt) {
delete(c.entries, k)
}
}
c.mu.Unlock()
}
}
上游 UDP 转发(upstream.go):
package main
import (
"fmt"
"net"
"time"
)
// UpstreamResolver 向上游 DNS 服务器转发查询
type UpstreamResolver struct {
servers []string // 上游 DNS 服务器地址列表,如 ["8.8.8.8:53", "1.1.1.1:53"]
timeout time.Duration
}
func NewUpstreamResolver(servers []string) *UpstreamResolver {
return &UpstreamResolver{
servers: servers,
timeout: 5 * time.Second,
}
}
// Query 向上游转发查询,返回原始响应字节
func (r *UpstreamResolver) Query(query []byte) ([]byte, error) {
var lastErr error
for _, server := range r.servers {
resp, err := r.queryServer(server, query)
if err != nil {
lastErr = err
continue
}
return resp, nil
}
return nil, fmt.Errorf("all upstream servers failed, last error: %w", lastErr)
}
func (r *UpstreamResolver) queryServer(server string, query []byte) ([]byte, error) {
conn, err := net.DialTimeout("udp", server, r.timeout)
if err != nil {
return nil, fmt.Errorf("connecting to %s: %w", server, err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(r.timeout))
_, err = conn.Write(query)
if err != nil {
return nil, fmt.Errorf("sending query to %s: %w", server, err)
}
buf := make([]byte, 4096) // 支持 EDNS0 的大响应
n, err := conn.Read(buf)
if err != nil {
return nil, fmt.Errorf("reading response from %s: %w", server, err)
}
return buf[:n], nil
}
DNS over HTTPS 上游(doh.go):
package main
import (
"bytes"
"fmt"
"io"
"net/http"
"time"
)
// DoHResolver 通过 HTTPS 向上游发送 DNS 查询(RFC 8484)
type DoHResolver struct {
url string
client *http.Client
}
func NewDoHResolver(url string) *DoHResolver {
return &DoHResolver{
url: url,
client: &http.Client{
Timeout: 5 * time.Second,
},
}
}
// Query 使用 DNS Wireformat over HTTPS(application/dns-message)发送查询
func (r *DoHResolver) Query(query []byte) ([]byte, error) {
req, err := http.NewRequest("POST", r.url, bytes.NewReader(query))
if err != nil {
return nil, err
}
// RFC 8484 规定的 Content-Type
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("Accept", "application/dns-message")
resp, err := r.client.Do(req)
if err != nil {
return nil, fmt.Errorf("DoH request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("DoH server returned status %d", resp.StatusCode)
}
return io.ReadAll(io.LimitReader(resp.Body, 65535))
}
广告拦截列表(blocklist.go):
package main
import (
"bufio"
"os"
"strings"
"sync"
)
// Blocklist 持有被拦截的域名集合
type Blocklist struct {
mu sync.RWMutex
domains map[string]struct{}
}
func NewBlocklist() *Blocklist {
return &Blocklist{
domains: make(map[string]struct{}),
}
}
// LoadFromFile 从 hosts 格式文件加载拦截列表
// 支持格式:0.0.0.0 ads.example.com 或 # 注释
func (b *Blocklist) LoadFromFile(path string) error {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
b.mu.Lock()
defer b.mu.Unlock()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
if len(fields) >= 2 {
// hosts 格式:IP 域名
domain := strings.ToLower(fields[1])
b.domains[domain] = struct{}{}
}
}
return scanner.Err()
}
// IsBlocked 检查域名是否被拦截,同时检查父域名
func (b *Blocklist) IsBlocked(name string) bool {
name = strings.ToLower(strings.TrimSuffix(name, "."))
b.mu.RLock()
defer b.mu.RUnlock()
// 精确匹配
if _, ok := b.domains[name]; ok {
return true
}
// 检查父域名(支持通配符拦截)
for {
idx := strings.Index(name, ".")
if idx == -1 {
break
}
name = name[idx+1:]
if _, ok := b.domains[name]; ok {
return true
}
}
return false
}
// buildNXDOMAINResponse 构建 NXDOMAIN 响应,用于被拦截的域名
func buildNXDOMAINResponse(query []byte) []byte {
if len(query) < 12 {
return nil
}
response := make([]byte, len(query))
copy(response, query)
// 将 QR 置 1(响应)、设置 RCODE 为 3(NXDOMAIN)
// Flags 在字节 2-3
flags := binary.BigEndian.Uint16(query[2:4])
flags |= 0x8000 // QR = 1
flags |= 0x0003 // RCODE = 3 (NXDOMAIN)
flags &^= 0x0200 // AA = 0
binary.BigEndian.PutUint16(response[2:4], flags)
return response
}
主服务器逻辑(server.go):
package main
import (
"encoding/binary"
"log"
"net"
"sync"
)
// DNSProxy 是完整的 DNS 代理服务器
type DNSProxy struct {
conn *net.UDPConn
cache *DNSCache
upstream *UpstreamResolver
doh *DoHResolver
blocklist *Blocklist
workers int
jobs chan dnsJob
}
type dnsJob struct {
data []byte
clientAddr *net.UDPAddr
}
func NewDNSProxy(listenAddr string, workers int) (*DNSProxy, error) {
addr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return nil, err
}
// 增大 socket 接收缓冲区,减少高并发时的丢包
conn.SetReadBuffer(4 * 1024 * 1024)
conn.SetWriteBuffer(4 * 1024 * 1024)
return &DNSProxy{
conn: conn,
cache: NewDNSCache(),
upstream: NewUpstreamResolver([]string{"8.8.8.8:53", "1.1.1.1:53"}),
doh: NewDoHResolver("https://cloudflare-dns.com/dns-query"),
blocklist: NewBlocklist(),
workers: workers,
jobs: make(chan dnsJob, workers*10),
}, nil
}
func (p *DNSProxy) Start() {
// 启动 worker pool
var wg sync.WaitGroup
for i := 0; i < p.workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
p.worker()
}()
}
// 主循环:接收 UDP 数据报
buf := make([]byte, 4096)
for {
n, clientAddr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Printf("ReadFromUDP error: %v", err)
continue
}
// 复制数据,避免下次 ReadFromUDP 覆盖
data := make([]byte, n)
copy(data, buf[:n])
// 非阻塞发送到 jobs channel,防止慢客户端拖慢接收循环
select {
case p.jobs <- dnsJob{data: data, clientAddr: clientAddr}:
default:
log.Printf("Worker pool full, dropping query from %s", clientAddr)
}
}
}
func (p *DNSProxy) worker() {
for job := range p.jobs {
response := p.handleQuery(job.data)
if response != nil {
if _, err := p.conn.WriteToUDP(response, job.clientAddr); err != nil {
log.Printf("WriteToUDP error: %v", err)
}
}
}
}
func (p *DNSProxy) handleQuery(data []byte) []byte {
msg, err := parseMessage(data)
if err != nil || len(msg.Questions) == 0 {
return nil
}
q := msg.Questions[0]
// 检查拦截列表
if p.blocklist.IsBlocked(q.Name) {
log.Printf("Blocked: %s", q.Name)
return buildNXDOMAINResponse(data)
}
key := cacheKey(q.Name, q.Type)
// 检查缓存
if cached, ok := p.cache.Get(key); ok {
// 修正响应中的事务 ID 以匹配本次查询
response := make([]byte, len(cached))
copy(response, cached)
binary.BigEndian.PutUint16(response[0:2], msg.Header.ID)
return response
}
// 向上游转发
upstream, err := p.upstream.Query(data)
if err != nil {
log.Printf("Upstream error for %s: %v", q.Name, err)
// 降级到 DoH
upstream, err = p.doh.Query(data)
if err != nil {
log.Printf("DoH error for %s: %v", q.Name, err)
return nil
}
}
// 提取最小 TTL,存入缓存
minTTL := extractMinTTL(upstream)
if minTTL > 0 {
p.cache.Set(key, upstream, minTTL)
}
return upstream
}
// extractMinTTL 从 DNS 响应中提取所有资源记录的最小 TTL
func extractMinTTL(response []byte) uint32 {
if len(response) < 12 {
return 0
}
anCount := int(binary.BigEndian.Uint16(response[6:8]))
if anCount == 0 {
return 0
}
var minTTL uint32 = ^uint32(0) // 初始化为最大值
offset := 12
// 跳过 Question 部分
qdCount := int(binary.BigEndian.Uint16(response[4:6]))
for i := 0; i < qdCount; i++ {
_, newOffset, err := parseDomainName(response, offset)
if err != nil {
return 0
}
offset = newOffset + 4
}
// 遍历 Answer 记录
for i := 0; i < anCount; i++ {
_, newOffset, err := parseDomainName(response, offset)
if err != nil {
return 0
}
offset = newOffset + 4 // 跳过 TYPE 和 CLASS
if offset+6 > len(response) {
return 0
}
ttl := binary.BigEndian.Uint32(response[offset : offset+4])
if ttl < minTTL {
minTTL = ttl
}
rdLength := int(binary.BigEndian.Uint16(response[offset+4 : offset+6]))
offset += 6 + rdLength
}
if minTTL == ^uint32(0) {
return 0
}
return minTTL
}
func main() {
proxy, err := NewDNSProxy(":5353", 100)
if err != nil {
log.Fatalf("Failed to create DNS proxy: %v", err)
}
if err := proxy.blocklist.LoadFromFile("blocklist.txt"); err != nil {
log.Printf("Warning: could not load blocklist: %v", err)
}
log.Println("DNS proxy listening on :5353")
proxy.Start()
}
Level 4 · 进阶与边界
DNS 负载均衡:Round-Robin A 记录
DNS 轮询是最简单的负载均衡方式。权威服务器为同一个域名返回多个 A 记录,每次查询时按不同顺序排列,客户端通常使用第一个 IP:
// 在响应构建时随机打乱 A 记录顺序,实现简单的轮询
func shuffleARecords(response []byte) []byte {
// 解析所有 Answer 中的 A 记录,随机排列后重新打包
// 这是一个简化示意,完整实现需要重新序列化报文
// 生产环境建议使用 miekg/dns 库
return response
}
更成熟的 DNS 负载均衡会结合地理位置路由(GeoDNS)——根据客户端 IP 的地理位置返回最近的服务器 IP,这需要 IP 地理位置数据库(如 MaxMind GeoIP2)。
miekg/dns 库深入
手工解析 DNS 报文虽然能让你深入理解协议细节,但生产环境应该使用 github.com/miekg/dns 这个成熟的库,它处理了所有边界情况:
import "github.com/miekg/dns"
// 使用 miekg/dns 实现相同的代理功能
func handleWithMiekg(w dns.ResponseWriter, r *dns.Msg) {
// 检查拦截列表
if len(r.Question) > 0 && isBlocked(r.Question[0].Name) {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeNameError) // NXDOMAIN
w.WriteMsg(m)
return
}
// 向上游转发
c := new(dns.Client)
resp, _, err := c.Exchange(r, "8.8.8.8:53")
if err != nil {
dns.HandleFailed(w, r)
return
}
w.WriteMsg(resp)
}
func main() {
dns.HandleFunc(".", handleWithMiekg)
server := &dns.Server{Addr: ":5353", Net: "udp"}
log.Fatal(server.ListenAndServe())
}
miekg/dns 的 dns.Server 内部使用了 goroutine-per-request 模型,但限制了最大并发数,并正确处理了 EDNS0、DNSSEC 等复杂特性。
DNSSEC 基础
DNSSEC(DNS Security Extensions)通过数字签名保证 DNS 响应的真实性和完整性,防止 DNS 缓存投毒攻击。它引入了几种新的记录类型:
- DNSKEY:区域的公钥
- RRSIG:资源记录集的数字签名
- DS(Delegation Signer):子域密钥的摘要,存储在父区域
- NSEC/NSEC3:证明某个名字不存在(防止枚举的 NSEC3 使用哈希)
验证 DNSSEC 的链式信任:从根区域的信任锚(IANA 公布的根密钥)开始,逐级验证每一层的 DS 记录和 DNSKEY 记录,最终验证回答记录的 RRSIG 签名。
DNS 放大攻击防御
DNS 反射放大攻击(Amplification Attack)是一种 DDoS 攻击:攻击者伪造受害者的 IP 发送 DNS 查询,响应发往受害者。DNS 的放大倍数可以很高(某些查询响应/请求比例超过 100:1)。
作为 DNS 服务器开发者,防御措施:
// 1. 限制响应中 ANY 查询(放大攻击常用)
if q.Type == dns.TypeANY {
// 返回 HINFO 记录代替全量响应(RFC 8482)
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.HINFO{
Hdr: dns.RR_Header{
Name: q.Name, Rrtype: dns.TypeHINFO,
Class: dns.ClassINET, Ttl: 0,
},
Cpu: "ANY obsoleted",
Os: "See RFC 8482",
})
w.WriteMsg(m)
return
}
// 2. Response Rate Limiting (RRL):对同一来源 IP 的响应速率限制
// 3. 开放解析器限制:只响应内部 IP 的递归查询
// 4. 设置 TC 位,强制小响应,让攻击者切换 TCP(无法伪造 IP)
实现简单权威 DNS 服务器
// 权威 DNS 服务器:持有区域记录,直接回答查询
type AuthoritativeServer struct {
zone map[string][]dns.RR // 键是 "name:type",值是资源记录列表
mu sync.RWMutex
}
func (s *AuthoritativeServer) Handle(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
q := r.Question[0]
key := fmt.Sprintf("%s:%d", strings.ToLower(q.Name), q.Qtype)
s.mu.RLock()
records, ok := s.zone[key]
s.mu.RUnlock()
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true // AA 位置 1,表示权威回答
if !ok {
m.SetRcode(r, dns.RcodeNameError) // NXDOMAIN
} else {
m.Answer = records
}
w.WriteMsg(m)
}
DNS 是互联网最关键的基础设施之一,也是 UDP 编程最经典的应用场景。从手写解析器到 miekg/dns 的高阶封装,从简单代理到权威服务器,每一层都揭示了互联网底层运转的精密机制。理解这些,你才能在网络编程中真正做到"知其然,更知其所以然"。