Chapter 38

Building a Custom MCP Server

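Once you understand the MCP protocol specification, the best way to learn it is to build a server yourself. This chapter walks through a complete hands-on project, a "database query" MCP Server, covering the whole process from setting up the environment to registering and using the server in Hermes. You will see how to define a Tool Schema, implement tool logic, handle errors, and debug an MCP Server.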


38.1 Project Plan: A Database Query MCP Server

Feature Design

The db-query-mcp-server we are building will provide the following capabilities:

db-query-mcp-server
├── Tools
│   ├── query_sql       # Run a read-only SQL query
│   ├── list_tables     # List all table names
│   └── describe_table  # Show a table's structure
│
├── Resources
│   ├── db://schema     # Full database schema
│   └── db://tables/{name}  # Preview of one table's data
│
└── Prompts (prompt templates)
    └── analyze_table   # Prompt template for analyzing a table's data
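The chapter does not list resources.py. As a minimal sketch, assuming the server resolves the two resource URIs by hand, a read_resource handler could map db://schema and db://tables/{name} to actions as below. resolve_resource_uri is a hypothetical helper name, and the table-name rule reuses the identifier pattern that DescribeTableInput applies in 38.3.

```python
import re
from typing import Optional, Tuple
from urllib.parse import urlsplit

# Same identifier rule as DescribeTableInput.table_name (assumed to apply here too)
TABLE_NAME_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")


def resolve_resource_uri(uri: str) -> Tuple[str, Optional[str]]:
    """Map a resource URI to (kind, table_name). Hypothetical helper.

    db://schema        -> ("schema", None)
    db://tables/users  -> ("table_preview", "users")
    """
    parts = urlsplit(uri)
    if parts.scheme != "db":
        raise ValueError(f"Unsupported URI scheme: {parts.scheme!r}")
    if parts.query or parts.fragment:
        raise ValueError("Query strings and fragments are not supported")
    # urlsplit puts the first segment after "db://" into netloc
    if parts.netloc == "schema" and parts.path in ("", "/"):
        return "schema", None
    if parts.netloc == "tables":
        name = parts.path.lstrip("/")
        # fullmatch (not match with $) so "users\n" is rejected as well
        if TABLE_NAME_RE.fullmatch(name) is None:
            raise ValueError(f"Invalid table name: {name!r}")
        return "table_preview", name
    raise ValueError(f"Unknown resource: {uri}")


print(resolve_resource_uri("db://schema"))        # ('schema', None)
print(resolve_resource_uri("db://tables/users"))  # ('table_preview', 'users')
```

Validating the name before it reaches SQL matters because a table preview would typically interpolate the table name into a query string.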

Directory Structure

db-query-mcp-server/
├── pyproject.toml
├── README.md
├── src/
│   └── db_query_mcp/
│       ├── __init__.py
│       ├── server.py       # MCP Server entry point
│       ├── tools.py        # Tool implementations
│       ├── resources.py    # Resource implementations
│       ├── prompts.py      # Prompt definitions
│       ├── db.py           # Database connection layer
│       └── schemas.py      # Input validation schemas
└── tests/
    ├── test_tools.py
    └── test_server.py
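prompts.py from the tree above is not shown in this chapter either. Under the MCP specification, a prompts/get result carries an optional description and a list of messages, each with a role ("user" or "assistant") and a content block such as {"type": "text", "text": ...}. Here is a stdlib-only sketch of what the analyze_table template might return; build_analyze_table_prompt is a hypothetical helper and the prompt wording is purely illustrative. In the real server the same data would be wrapped in the SDK's GetPromptResult and PromptMessage types, which server.py imports.

```python
import json


def build_analyze_table_prompt(table_name: str) -> dict:
    """Return a prompts/get result for analyze_table as a plain dict.

    Field names follow the MCP prompt message shape; the instruction
    text is an illustrative assumption, not part of the chapter.
    """
    instructions = (
        f"Analyze the table `{table_name}`. "
        "First call describe_table to inspect its columns, then use "
        "query_sql with read-only SELECT statements to sample the data. "
        "Summarize notable patterns, missing values and outliers."
    )
    return {
        "description": f"Analyze the data in table {table_name}",
        "messages": [
            {"role": "user", "content": {"type": "text", "text": instructions}},
        ],
    }


print(json.dumps(build_analyze_table_prompt("users"), indent=2))
```

Pointing the prompt at the server's own tools (describe_table, query_sql) is one way to make a template useful: the model is guided toward the read-only path the server already enforces.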

38.2 Installing and Initializing the MCP SDK

Install Dependencies

# Create the project
mkdir db-query-mcp-server && cd db-query-mcp-server
python -m venv .venv && source .venv/bin/activate

# Install the MCP SDK (the official Python SDK)
pip install mcp

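# Install the other dependencies. A comment after a trailing "\" would
# break the line continuation, so the notes are listed here instead:
#   aiosqlite  async SQLite (development and testing)
#   asyncpg    async PostgreSQL
#   pydantic   input validation
#   structlog  structured logging
pip install aiosqlite asyncpg pydantic structlog

# Test dependencies (used in 38.8)
pip install pytest pytest-asyncio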

pyproject.toml Configuration

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "db-query-mcp-server"
version = "0.1.0"
description = "MCP Server for safe, read-only database queries"
requires-python = ">=3.10"
dependencies = [
    "mcp>=0.9.0",
    "aiosqlite>=0.20.0",
    "pydantic>=2.0",
    "structlog>=24.0",
]

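[project.scripts]
# Command-line entry point (starts the server in stdio mode)
db-query-mcp = "db_query_mcp.server:main"

[tool.hatch.build.targets.wheel]
# The package directory (db_query_mcp) does not match the project name,
# so tell Hatchling where to find it under src/
packages = ["src/db_query_mcp"]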

38.3 Defining the Tool Schema

schemas.py: input validation

# src/db_query_mcp/schemas.py
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List
import re

class QuerySQLInput(BaseModel):
    """Input parameters for running an SQL query"""
    sql: str = Field(
        description="SQL query to run (only SELECT is supported)",
        min_length=1,
        max_length=10000
    )
    params: Optional[List] = Field(
        default=None,
        description="Bound SQL parameters (guards against SQL injection)"
    )
    limit: int = Field(
        default=100,
        ge=1,
        le=10000,
        description="Maximum number of rows to return (1-10000)"
    )
    
    @field_validator("sql")
    @classmethod
    def validate_readonly(cls, v: str) -> str:
        """Ensure the SQL is a read-only statement"""
        normalized = v.strip().upper()
        
        # Only SELECT statements are allowed
        if not normalized.startswith("SELECT"):
            first_word = normalized.split()[0] if normalized else "(empty)"
            raise ValueError(
                "Security restriction: only SELECT queries are allowed. "
                f"Detected statement type: {first_word}"
            )
        
        # Reject dangerous keywords (defends against SQL injection variants).
        # Keywords must be upper-case because they are matched against `normalized`.
        dangerous_keywords = [
            "DROP", "DELETE", "UPDATE", "INSERT",
            "ALTER", "TRUNCATE", "EXEC", "EXECUTE",
            "XP_CMDSHELL", "SP_EXECUTESQL"
        ]
        for kw in dangerous_keywords:
            if re.search(r'\b' + kw + r'\b', normalized):
                raise ValueError(f"Security restriction: the SQL contains a forbidden keyword: {kw}")
        
        return v


class ListTablesInput(BaseModel):
    """Input parameters for listing tables"""
    schema_name: Optional[str] = Field(
        default=None,
        description="Database schema name (optional; defaults to the current schema)"
    )
    include_views: bool = Field(
        default=False,
        description="Whether to include views"
    )


class DescribeTableInput(BaseModel):
    """Input parameters for describing a table's structure"""
    table_name: str = Field(
        description="Table name",
        min_length=1,
        max_length=255,
        pattern=r'^[a-zA-Z_][a-zA-Z0-9_]*$'
    )
    include_indexes: bool = Field(
        default=True,
        description="Whether to include index information"
    )
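To see what the two guards in schemas.py accept and reject without installing Pydantic, here is a stdlib-only sketch that mirrors validate_readonly and the table_name pattern (check_readonly, is_safe_identifier and accepted are hypothetical names). It also shows the limits of a keyword blacklist: a harmless string literal containing a forbidden word is rejected, and read-only CTEs (WITH ... SELECT) are refused because they do not start with SELECT.

```python
import re

DANGEROUS_KEYWORDS = [
    "DROP", "DELETE", "UPDATE", "INSERT",
    "ALTER", "TRUNCATE", "EXEC", "EXECUTE",
    "XP_CMDSHELL", "SP_EXECUTESQL",
]
# Same rule as DescribeTableInput.table_name
IDENTIFIER_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")


def check_readonly(sql: str) -> None:
    """Raise ValueError unless sql looks like a read-only SELECT (mirrors validate_readonly)."""
    normalized = sql.strip().upper()
    if not normalized.startswith("SELECT"):
        raise ValueError("only SELECT queries are allowed")
    for kw in DANGEROUS_KEYWORDS:
        if re.search(r"\b" + kw + r"\b", normalized):
            raise ValueError(f"forbidden keyword: {kw}")


def is_safe_identifier(name: str) -> bool:
    """True if name is a plain identifier that is safe to interpolate."""
    return IDENTIFIER_RE.fullmatch(name) is not None


def accepted(sql: str) -> bool:
    try:
        check_readonly(sql)
        return True
    except ValueError:
        return False


print(accepted("SELECT id, updated_at FROM users"))           # True: \b keeps UPDATE from matching UPDATED_AT
print(accepted("DELETE FROM users"))                          # False: not a SELECT
print(accepted("SELECT 1; DROP TABLE users"))                 # False: forbidden keyword
print(accepted("SELECT * FROM notes WHERE body = 'delete'"))  # False: false positive on a literal
print(accepted("WITH t AS (SELECT 1) SELECT * FROM t"))       # False: CTEs are refused
print(is_safe_identifier("users"))                            # True
print(is_safe_identifier("users); DROP TABLE users; --"))     # False
```

The blacklist is therefore a first line of defense rather than a guarantee; the row limit and the subquery wrapper in the database layer add further protection.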

38.4 The Database Connection Layer

# src/db_query_mcp/db.py
import asyncio
import aiosqlite
from typing import Any, Dict, List, Optional
from contextlib import asynccontextmanager
import structlog

logger = structlog.get_logger()

class DatabaseConnection:
    """Database connection manager (with a simple connection pool)"""
    
    def __init__(self, connection_string: str, pool_size: int = 5):
        self.connection_string = connection_string
        self.pool_size = pool_size
        self._pool: List[aiosqlite.Connection] = []
        self._semaphore = asyncio.Semaphore(pool_size)
        self._is_initialized = False
    
    async def initialize(self):
        """Initialize the connection pool"""
        if self._is_initialized:
            return
        
        # SQLite demo version (use asyncpg for PostgreSQL in production).
        # Note: with ":memory:", every connection opens its own separate, empty database.
        for _ in range(self.pool_size):
            conn = await aiosqlite.connect(self.connection_string)
            conn.row_factory = aiosqlite.Row  # rows can be accessed by column name
            await conn.execute("PRAGMA journal_mode=WAL")
            await conn.execute("PRAGMA foreign_keys=ON")
            self._pool.append(conn)
        
        self._is_initialized = True
        logger.info("Database connection pool initialized", pool_size=self.pool_size)
    
    @asynccontextmanager
    async def acquire(self):
        """Acquire a connection from the pool"""
        async with self._semaphore:
            conn = self._pool[0]  # Simplified: always the first connection; rotate through the pool in production
            try:
                yield conn
            except Exception as e:
                logger.error("Database operation failed", error=str(e))
                raise
    
    async def query(
        self,
        sql: str,
        params: Optional[List] = None,
        limit: int = 100
    ) -> Dict[str, Any]:
        """
        Run the query and return the result
        
        Returns:
            {
                "rows": [...],
                "columns": [...],
                "row_count": int,
                "truncated": bool
            }
        """
        async with self.acquire() as conn:
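            # Wrap the query to enforce a row cap (harmless even if the SQL already has
            # its own LIMIT). Fetch limit + 1 rows so truncation can be detected, and drop
            # any trailing semicolon, which would be a syntax error inside the subquery.
            inner_sql = sql.strip().rstrip(";")
            limited_sql = f"SELECT * FROM ({inner_sql}) AS __query LIMIT {limit + 1}"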
            
            try:
                async with conn.execute(limited_sql, params or []) as cursor:
                    rows = await cursor.fetchall()
                    columns = [desc[0] for desc in cursor.description]
                    
                    truncated = len(rows) > limit
                    if truncated:
                        rows = rows[:limit]
                    
                    return {
                        "rows": [dict(row) for row in rows],
                        "columns": columns,
                        "row_count": len(rows),
                        "truncated": truncated
                    }
            except aiosqlite.Error as e:
                raise DatabaseQueryError(f"SQL execution error: {e}") from e
    
    async def get_tables(self, include_views: bool = False) -> List[str]:
        """Return all table names"""
        type_filter = "IN ('table', 'view')" if include_views else "= 'table'"
        async with self.acquire() as conn:
            async with conn.execute(
                f"SELECT name FROM sqlite_master WHERE type {type_filter} ORDER BY name"
            ) as cursor:
                return [row[0] for row in await cursor.fetchall()]
    
    async def get_table_schema(self, table_name: str) -> List[Dict]:
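        """Return a table's column definitions"""
        async with self.acquire() as conn:
            # PRAGMA statements do not accept bound parameters, so table_name is
            # interpolated; DescribeTableInput restricts it to a plain identifier
            async with conn.execute(
                f"PRAGMA table_info({table_name})"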
            ) as cursor:
                columns = await cursor.fetchall()
                return [
                    {
                        "name": col["name"],
                        "type": col["type"],
                        "nullable": not col["notnull"],
                        "default": col["dflt_value"],
                        "primary_key": bool(col["pk"])
                    }
                    for col in columns
                ]
    
    async def close(self):
        for conn in self._pool:
            await conn.close()
        logger.info("Database connection pool closed")


class DatabaseQueryError(Exception):
    pass
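acquire() above always hands out self._pool[0], so the pool_size connections are never actually rotated. One way to implement the rotation its comment mentions is to keep idle connections in an asyncio.Queue: a caller waits until a connection is free, and the connection goes back into the queue when the with block exits. The sketch below is generic and stdlib-only (QueuePool and the string stand-ins for connections are hypothetical). It assumes a file-based SQLite database, because with DB_PATH=":memory:" each aiosqlite connection would open a separate, empty in-memory database.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import Any, List


class QueuePool:
    """Hand out pooled connections, one caller per connection (hypothetical helper)."""

    def __init__(self, connections: List[Any]):
        # Create the pool inside a running event loop
        self._idle: asyncio.Queue = asyncio.Queue()
        for conn in connections:
            self._idle.put_nowait(conn)

    @asynccontextmanager
    async def acquire(self):
        conn = await self._idle.get()  # waits while every connection is in use
        try:
            yield conn
        finally:
            self._idle.put_nowait(conn)  # returned even if the caller raised


async def demo() -> List[str]:
    pool = QueuePool(["conn-0", "conn-1"])  # stand-ins for aiosqlite connections
    used: List[str] = []

    async def worker() -> None:
        async with pool.acquire() as conn:
            used.append(conn)
            await asyncio.sleep(0.01)  # simulate a query

    # Four workers share two connections; two of them wait for a free one
    await asyncio.gather(*(worker() for _ in range(4)))
    return used


print(asyncio.run(demo()))
```

Because the queue itself limits concurrency to the number of connections, the separate Semaphore becomes unnecessary with this design.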

38.5 Implementing handle_call_tool

tools.py: tool implementations

# src/db_query_mcp/tools.py
import json
from typing import Any, Dict, List
from mcp.types import Tool, TextContent, CallToolResult
from .db import DatabaseConnection
from .schemas import QuerySQLInput, ListTablesInput, DescribeTableInput
import structlog

logger = structlog.get_logger()

def get_tool_definitions() -> List[Tool]:
    """Return the definitions (schemas) of all tools"""
    return [
        Tool(
            name="query_sql",
            description=(
                "Run a read-only SQL query against the database. "
                "Only SELECT statements are supported, and the number of returned rows "
                "is capped automatically to avoid oversized results. "
                "Pass values through params (parameterized queries) to prevent SQL injection."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "sql": {
                        "type": "string",
                        "description": "The SELECT statement to run"
                    },
                    "params": {
                        "type": "array",
                        "items": {},
                        "description": "List of SQL bind parameters (e.g. [1, 'active'])"
                    },
                    "limit": {
                        "type": "integer",
                        "default": 100,
                        "minimum": 1,
                        "maximum": 10000,
                        "description": "Maximum number of rows to return"
                    }
                },
                "required": ["sql"]
            }
        ),
        Tool(
            name="list_tables",
            description="List the names of all tables (and optionally views) in the database",
            inputSchema={
                "type": "object",
                "properties": {
                    "include_views": {
                        "type": "boolean",
                        "default": False,
                        "description": "Whether to include views in the result"
                    }
                }
            }
        ),
        Tool(
            name="describe_table",
            description=(
                "Show the structure of a table: column names, data types, "
                "nullability, default values and primary-key information."
            ),
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table to inspect (letters, digits and underscores only)"
                    }
                },
                "required": ["table_name"]
            }
        )
    ]


async def handle_call_tool(
    name: str,
    arguments: Dict[str, Any],
    db: DatabaseConnection
) -> CallToolResult:
    """
    Central dispatcher for all tool calls
    
    Args:
        name: tool name
        arguments: tool arguments
        db: database connection
    
    Returns:
        CallToolResult: a list of content blocks plus the isError flag
    """
    logger.info("Tool call", tool=name, args=arguments)
    
    try:
        if name == "query_sql":
            return await _handle_query_sql(arguments, db)
        
        elif name == "list_tables":
            return await _handle_list_tables(arguments, db)
        
        elif name == "describe_table":
            return await _handle_describe_table(arguments, db)
        
        else:
            return CallToolResult(
                content=[TextContent(type="text", text=f"Unknown tool: {name}")],
                isError=True
            )
    
    except ValueError as e:
        # Input validation error (pydantic.ValidationError is a ValueError subclass)
        logger.warning("Tool input validation failed", tool=name, error=str(e))
        return CallToolResult(
            content=[TextContent(
                type="text",
                text=f"Input validation error: {e}"
            )],
            isError=True
        )
    
    except Exception as e:
        # Execution error
        logger.error("Tool execution failed", tool=name, error=str(e), exc_info=True)
        return CallToolResult(
            content=[TextContent(
                type="text",
                text=f"Execution error: {e}"
            )],
            isError=True
        )


async def _handle_query_sql(
    arguments: Dict,
    db: DatabaseConnection
) -> CallToolResult:
    # Validate the input
    params = QuerySQLInput(**arguments)
    
    # Run the query
    result = await db.query(params.sql, params.params, params.limit)
    
    # Format the output
    rows = result["rows"]
    columns = result["columns"]
    
    if not rows:
        return CallToolResult(
            content=[TextContent(type="text", text="The query returned 0 rows.")],
            isError=False
        )
    
    # Build a Markdown table
    table_lines = [
        "| " + " | ".join(str(c) for c in columns) + " |",
        "| " + " | ".join("---" for _ in columns) + " |",
    ]
    for row in rows:
        table_lines.append(
            "| " + " | ".join(str(row.get(c, "")) for c in columns) + " |"
        )
    
    summary = (
        f"Query succeeded and returned {result['row_count']} rows"
        + (f" (results truncated to the first {params.limit} rows)"
           if result["truncated"] else "")
        + ":\n\n"
        + "\n".join(table_lines)
    )
    
    # Also return raw JSON (first 10 rows) for programmatic use;
    # default=str keeps values such as BLOB bytes from breaking serialization
    return CallToolResult(
        content=[
            TextContent(type="text", text=summary),
            TextContent(
                type="text",
                text="**Raw JSON data (first 10 rows):**\n```json\n" 
                     + json.dumps(rows[:10], ensure_ascii=False, indent=2, default=str)
                     + "\n```"
            )
        ],
        isError=False
    )


async def _handle_list_tables(
    arguments: Dict,
    db: DatabaseConnection
) -> CallToolResult:
    params = ListTablesInput(**arguments)
    tables = await db.get_tables(params.include_views)
    
    if not tables:
        return CallToolResult(
            content=[TextContent(type="text", text="The database contains no tables.")],
            isError=False
        )
    
    table_list = "\n".join(f"- `{t}`" for t in tables)
    return CallToolResult(
        content=[TextContent(
            type="text",
            text=f"The database contains {len(tables)} tables:\n\n{table_list}"
        )],
        isError=False
    )


async def _handle_describe_table(
    arguments: Dict,
    db: DatabaseConnection
) -> CallToolResult:
    params = DescribeTableInput(**arguments)
    schema = await db.get_table_schema(params.table_name)
    
    if not schema:
        return CallToolResult(
            content=[TextContent(
                type="text",
                text=f"Table `{params.table_name}` does not exist or has no columns."
            )],
            isError=True
        )
    
    lines = [
        f"## Table structure: `{params.table_name}`\n",
        "| Column | Type | Nullable | Default | Primary key |",
        "| ------ | ---- | -------- | ------- | ----------- |",
    ]
    for col in schema:
        lines.append(
            f"| `{col['name']}` | {col['type']} | "
            f"{'yes' if col['nullable'] else 'no'} | "
            f"{col['default'] or '-'} | "
            f"{'✓' if col['primary_key'] else ''} |"
        )
    
    return CallToolResult(
        content=[TextContent(type="text", text="\n".join(lines))],
        isError=False
    )
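_handle_query_sql renders every value with str(...) inside | ... | cells, so a value that itself contains | or a newline breaks the table layout, and NULLs appear as the word None. A small stdlib sketch of the same rendering with escaping; rows_to_markdown and _cell are hypothetical helpers, and escaping | as \| follows the GitHub-flavored Markdown table convention (exactly how a given client renders it may vary).

```python
from typing import Any, Dict, List


def _cell(value: Any) -> str:
    """Format one value for a Markdown table cell."""
    text = "NULL" if value is None else str(value)
    # Escape pipes and flatten newlines so the cell stays on one table row
    return text.replace("|", "\\|").replace("\r\n", " ").replace("\n", " ")


def rows_to_markdown(columns: List[str], rows: List[Dict[str, Any]]) -> str:
    lines = [
        "| " + " | ".join(_cell(c) for c in columns) + " |",
        "| " + " | ".join("---" for _ in columns) + " |",
    ]
    for row in rows:
        lines.append("| " + " | ".join(_cell(row.get(c)) for c in columns) + " |")
    return "\n".join(lines)


print(rows_to_markdown(
    ["id", "note"],
    [{"id": 1, "note": "a|b"}, {"id": 2, "note": None}],
))
# prints:
# | id | note |
# | --- | --- |
# | 1 | a\|b |
# | 2 | NULL |
```

Rendering None as NULL is a design choice that matches how SQL tools usually display missing values; the raw JSON block still carries the exact values.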

38.6 The Complete Server Implementation

# src/db_query_mcp/server.py
import asyncio
import os
import sys
from typing import Any, Dict
import structlog
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import (
    Tool, Resource, Prompt, TextContent,
    GetPromptResult, PromptMessage
)

from .db import DatabaseConnection
from .tools import get_tool_definitions, handle_call_tool
from .resources import get_resource_list, handle_read_resource
from .prompts import get_prompt_list, handle_get_prompt

# Configure structured logging to write to stderr, so it never interferes
# with the JSON-RPC messages on stdout
structlog.configure(
    processors=[structlog.dev.ConsoleRenderer()],
    wrapper_class=structlog.stdlib.BoundLogger,
    logger_factory=structlog.WriteLoggerFactory(file=sys.stderr),
)
logger = structlog.get_logger()

# Create the MCP Server instance
app = Server("db-query-mcp-server")

# Global database connection (created in run_server)
db: DatabaseConnection | None = None


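@app.list_tools()
async def list_tools() -> list[Tool]:
    """Return the definitions of all available tools"""
    return get_tool_definitions()


@app.call_tool()
async def call_tool(name: str, arguments: Dict[str, Any]) -> list[TextContent]:
    """Handle a tool call"""
    result = await handle_call_tool(name, arguments, db)
    # The low-level @app.call_tool() decorator expects a sequence of content
    # blocks and builds the CallToolResult itself. Raising an exception makes
    # the SDK return an isError=True result carrying the exception message.
    if result.isError:
        raise RuntimeError("\n".join(c.text for c in result.content))
    return result.content


@app.list_resources()
async def list_resources():
    """Return all available resources"""
    return await get_resource_list(db)


@app.read_resource()
async def read_resource(uri: str):
    """Read the content of a resource"""
    return await handle_read_resource(uri, db)


@app.list_prompts()
async def list_prompts():
    """Return all Prompt templates"""
    return get_prompt_list()


@app.get_prompt()
async def get_prompt(name: str, arguments: dict[str, str] | None = None):
    """Return the requested Prompt template"""
    return await handle_get_prompt(name, arguments or {}, db)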


async def run_server():
    """Start the MCP Server (stdio mode)"""
    global db
    
    # Read configuration from environment variables
    db_path = os.environ.get("DB_PATH", ":memory:")
    pool_size = int(os.environ.get("DB_POOL_SIZE", "5"))
    
    logger.info("Starting db-query-mcp-server", db_path=db_path)
    
    # Initialize the database connection pool
    db = DatabaseConnection(db_path, pool_size)
    await db.initialize()
    
    try:
        # Start the stdio transport
        async with stdio_server() as (read_stream, write_stream):
            await app.run(
                read_stream,
                write_stream,
                app.create_initialization_options()
            )
    finally:
        await db.close()
        logger.info("db-query-mcp-server stopped")


def main():
    """CLI entry point"""
    asyncio.run(run_server())


if __name__ == "__main__":
    main()
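The logging setup above sends structlog output to stderr because, in stdio mode, stdout carries only the protocol: one JSON-RPC message per line. The following toy, stdlib-only loop illustrates that framing; serve_lines is hypothetical and is not how the mcp SDK is implemented. It answers tools/list with a canned tool list, sends no response to notifications, and writes diagnostics to a separate log stream, which is why a stray print() to stdout would corrupt the message stream.

```python
import io
import json
import sys
from typing import TextIO


def serve_lines(stdin: TextIO, stdout: TextIO, log: TextIO = sys.stderr) -> None:
    """Toy newline-delimited JSON-RPC loop (illustration only)."""
    for line in stdin:
        line = line.strip()
        if not line:
            continue
        msg = json.loads(line)
        print(f"received {msg.get('method')}", file=log)  # diagnostics never go to stdout
        if "id" not in msg:
            continue  # notifications get no response
        if msg.get("method") == "tools/list":
            result = {"tools": [{"name": "list_tables", "inputSchema": {"type": "object"}}]}
            reply = {"jsonrpc": "2.0", "id": msg["id"], "result": result}
        else:
            # -32601 is the JSON-RPC 2.0 "Method not found" error code
            reply = {"jsonrpc": "2.0", "id": msg["id"],
                     "error": {"code": -32601, "message": "Method not found"}}
        stdout.write(json.dumps(reply) + "\n")  # exactly one message per line


requests = io.StringIO(
    '{"jsonrpc":"2.0","method":"notifications/initialized"}\n'
    '{"jsonrpc":"2.0","id":2,"method":"tools/list"}\n'
)
out, log = io.StringIO(), io.StringIO()
serve_lines(requests, out, log)
print(out.getvalue())
```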

38.7 Registering and Using the Server in Hermes

Option 1: Register through the Hermes configuration file

# hermes_config.yaml
mcp_servers:
  - name: "db-query"
    description: "Database query tools"
    command: ["db-query-mcp"]   # uses the entry point defined in pyproject.toml
    env:
      DB_PATH: "/data/production.db"
      DB_POOL_SIZE: "10"
    timeout: 30
    auto_approve_tools:
      - "list_tables"
      - "describe_table"
    # query_sql requires user confirmation (potentially risky)
    require_approval_tools:
      - "query_sql"

Option 2: Register programmatically

# register_mcp.py
import asyncio
from hermes import Agent
from hermes.mcp import StdioServerConfig

async def main():
    # Create the Agent
    agent = Agent(model="hermes-pro")
    
    # Register the MCP Server
    db_server = StdioServerConfig(
        name="db-query",
        command=["db-query-mcp"],
        env={"DB_PATH": "/data/production.db"},
        timeout=30
    )
    
    await agent.mcp.register(db_server)
    
    # Verify that the Server is connected
    tools = await agent.mcp.list_tools("db-query")
    print(f"Registered tools: {[t.name for t in tools]}")
    
    # Use the tools
    result = await agent.run(
        "Which tables are in the database, and how many rows does each one have?"
    )
    print(result.text)


if __name__ == "__main__":
    asyncio.run(main())

38.8 Debugging Tips

Method 1: MCP Inspector (the official debugging tool)

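# Install MCP Inspector
npm install -g @modelcontextprotocol/inspector

# Start the Inspector and point it at your Server
mcp-inspector db-query-mcp
# Without a global install: npx @modelcontextprotocol/inspector db-query-mcp

# The Inspector serves a Web UI at the local URL it prints on startup
# (http://localhost:6274 in recent versions). From there you can:
# - Browse all Tools/Resources/Prompts
# - Call Tools manually
# - Inspect the raw JSON-RPC messages
# - View the Server's logs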

Method 2: Manual JSON-RPC testing

# Start the Server and send a JSON-RPC message by hand
echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"0.1"}}}' | db-query-mcp
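The single echo above only performs the first step of the handshake. In the MCP lifecycle the client sends initialize, receives the response, then sends the notifications/initialized notification before making normal requests such as tools/list. This stdlib sketch prints that sequence, one JSON message per line, so it can be piped into the server (handshake_lines is a hypothetical helper). Because all messages are written at once instead of after each response, treat it as a quick smoke test; depending on the SDK version, the server may shut down when stdin reaches EOF before it has answered every request.

```python
import json

PROTOCOL_VERSION = "2024-11-05"  # same version as the echo example above


def handshake_lines() -> str:
    """Build initialize -> notifications/initialized -> tools/list as newline-delimited JSON."""
    messages = [
        {
            "jsonrpc": "2.0", "id": 1, "method": "initialize",
            "params": {
                "protocolVersion": PROTOCOL_VERSION,
                "capabilities": {},
                "clientInfo": {"name": "test", "version": "0.1"},
            },
        },
        # A notification: it has no "id", so the server sends no response
        {"jsonrpc": "2.0", "method": "notifications/initialized"},
        {"jsonrpc": "2.0", "id": 2, "method": "tools/list"},
    ]
    return "".join(json.dumps(m) + "\n" for m in messages)


if __name__ == "__main__":
    # Usage (hypothetical file name): python handshake.py | db-query-mcp
    print(handshake_lines(), end="")
```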

Method 3: Python integration tests

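# tests/test_server.py
import asyncio
import pytest
import pytest_asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# In pytest-asyncio's default strict mode, async fixtures must be
# declared with pytest_asyncio.fixture rather than pytest.fixture
@pytest_asyncio.fixture
async def mcp_session(tmp_path):
    """Create a test database and start the MCP Server"""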
    import aiosqlite
    
    # Create the test database
    db_path = str(tmp_path / "test.db")
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute("""
            CREATE TABLE users (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                email TEXT UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        await conn.execute(
            "INSERT INTO users (name, email) VALUES (?, ?)",
            ("Alice", "alice@example.com")
        )
        await conn.commit()
    
    # Start the MCP Server
    server_params = StdioServerParameters(
        command="db-query-mcp",
        env={"DB_PATH": db_path}
    )
    
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            yield session


@pytest.mark.asyncio
async def test_list_tools(mcp_session):
    tools = await mcp_session.list_tools()
    tool_names = [t.name for t in tools.tools]
    assert "query_sql" in tool_names
    assert "list_tables" in tool_names
    assert "describe_table" in tool_names


@pytest.mark.asyncio
async def test_list_tables(mcp_session):
    result = await mcp_session.call_tool("list_tables", {})
    assert not result.isError
    assert "users" in result.content[0].text


@pytest.mark.asyncio
async def test_query_sql_safe(mcp_session):
    result = await mcp_session.call_tool(
        "query_sql",
        {"sql": "SELECT * FROM users", "limit": 10}
    )
    assert not result.isError
    assert "Alice" in result.content[0].text


@pytest.mark.asyncio
async def test_query_sql_blocked_write(mcp_session):
    """Verify that write operations are rejected"""
    result = await mcp_session.call_tool(
        "query_sql",
        {"sql": "DELETE FROM users WHERE 1=1"}
    )
    assert result.isError
    # The rejection message explains that only SELECT queries are allowed
    assert "SELECT" in result.content[0].text

Chapter Summary

Through a complete hands-on project, this chapter built a production-grade MCP Server:

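  1. Project planning: defining the scope and design of the three capability types (Tool, Resource, Prompt)
  2. Using the SDK: the MCP Python SDK provides core components such as Server and stdio_server, which greatly reduce development effort
  3. Schema definition: validating input with Pydantic and describing tool interfaces with JSON Schema
  4. Safe tool implementation: SELECT-only queries, parameterized queries against injection, and row limits against oversized results
  5. The Server main module: decorator-driven routing (@app.call_tool()) with a clean separation of concerns
  6. Registration: two ways to register (configuration file and code) to suit different scenarios
  7. Debugging: three go-to techniques, namely MCP Inspector, manual JSON-RPC testing, and pytest integration tests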

Discussion Questions

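  1. The current implementation supports only SQLite. How would you modify DatabaseConnection to support PostgreSQL? What differences do you need to account for (connection strings, the RETURNING clause, type mapping)?
  2. How would you add an execution time limit to query_sql (for example, at most 10 seconds) and automatically cancel a query that exceeds it?
  3. Is an MCP Server stateless? If a user runs several queries in one conversation, does the Server need to maintain "query history" state?
  4. How would you switch this Server from stdio mode to SSE mode so that it can run as a standalone HTTP service?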