第 24 章

实战:数据分析助手——NL2SQL、图表解读与报告自动化

第24章:实战——数据分析助手——NL2SQL、图表解读与报告自动化

让业务人员用自然语言直接查询数据库、解读图表、生成报告——本章展示如何用 Dify 把数据分析能力交到每个人手里。

本章导读

"上周各区域的销售额是多少?" "这张图表说明了什么问题?" "帮我写一份季度数据分析报告"——这些请求每天都在向数据分析团队涌来,占用了分析师大量时间,但从业务角度每一个问题都很重要。

数据分析助手的核心是 NL2SQL(Natural Language to SQL):将用户的自然语言问题转换为精确的 SQL 查询,执行查询,并将结果以人类可读的方式呈现。这不是一个新想法,但在 LLM 时代,实现质量和易用性都有了质的飞跃。

案例背景:某零售连锁集团(全国 500 门店),数据团队 8 人,每日处理来自门店管理、市场营销、供应链等部门的数据查询请求约 150 个。核心数据库包含:

目标


Level 1:基础认知(1-3 年经验)

NL2SQL 的工作原理

NL2SQL 不是魔法,它的工作流程很清晰:

用户问题(自然语言)
        ↓
  数据库 Schema 上下文注入
        ↓
  LLM 生成 SQL
        ↓
  SQL 安全验证(防注入)
        ↓
  执行 SQL(只读权限)
        ↓
  结果格式化(表格/自然语言)
        ↓
  返回给用户

为什么 LLM 能生成正确的 SQL?

LLM 在训练数据中见过大量的 SQL,能理解 SQL 语法。但要生成针对你公司数据库的 SQL,它需要知道你的表结构(Schema)。这就是为什么"Schema 注入"是 NL2SQL 的核心。

类比:让 LLM 写 SQL,就像让一个懂 SQL 的程序员查你的数据库。这个程序员很聪明,但他需要先看数据库文档(Schema)才能写出正确的查询。

在 Dify 中配置 NL2SQL 工具

步骤一:准备数据库 Schema 文档

将数据库表结构整理为 LLM 易理解的格式:

# 数据库 Schema 文档

## 数据库:sales_db

### 表:daily_sales(日销售记录)
| 列名 | 类型 | 说明 | 示例 |
|------|------|------|------|
| id | BIGINT | 主键 | 10001 |
| store_id | INT | 门店ID | 42 |
| sale_date | DATE | 销售日期 | 2024-03-15 |
| product_id | INT | 商品ID | 5001 |
| product_name | VARCHAR | 商品名称 | "苹果手机壳" |
| category | VARCHAR | 商品类别 | "手机配件" |
| quantity | INT | 销量 | 5 |
| unit_price | DECIMAL | 单价(元) | 39.9 |
| total_amount | DECIMAL | 销售额(元) | 199.5 |
| promotion_id | INT | 促销活动ID(无促销为NULL)| 3 |

### 表:stores(门店信息)
| 列名 | 类型 | 说明 | 示例 |
|------|------|------|------|
| id | INT | 门店ID | 42 |
| store_name | VARCHAR | 门店名称 | "北京朝阳万达店" |
| city | VARCHAR | 城市 | "北京" |
| province | VARCHAR | 省份 | "北京" |
| region | VARCHAR | 大区(华北/华南/华东...)| "华北" |
| open_date | DATE | 开业日期 | 2020-05-01 |
| store_type | VARCHAR | 门店类型(旗舰店/标准店/小店)| "旗舰店" |
| area_sqm | INT | 面积(平方米) | 500 |

### 表:inventory(库存)
| 列名 | 类型 | 说明 | 示例 |
|------|------|------|------|
| store_id | INT | 门店ID | 42 |
| product_id | INT | 商品ID | 5001 |
| current_stock | INT | 当前库存 | 100 |
| last_updated | TIMESTAMP | 最后更新时间 | 2024-03-15 14:30:00 |

### 常用关联关系
- daily_sales.store_id = stores.id
- daily_sales.product_id = inventory.product_id + inventory.store_id

### 重要业务规则
- 本月 = 当前自然月
- 上周 = 上个完整的自然周(周一到周日)
- 华北大区 = province IN ('北京', '天津', '河北', '山西', '内蒙古')
- 销售额 = SUM(total_amount)
- 毛利率 = (销售额 - 成本) / 销售额(成本数据在 product_cost 表)

步骤二:配置 Dify Workflow

在 Dify 中创建 NL2SQL 工作流:

节点1:Schema 上下文准备
  → 从文件或数据库动态获取相关表的 Schema

节点2:SQL 生成(LLM)
  输入:用户问题 + Schema 上下文
  输出:SQL 查询语句

节点3:SQL 安全检查(Code 节点)
  → 拒绝 INSERT/UPDATE/DELETE/DROP 等危险操作
  → 检查 SQL 注入风险

节点4:SQL 执行(HTTP 节点)
  → 调用查询 API(不直接连数据库)

节点5:结果解读(LLM)
  输入:SQL 查询结果
  输出:自然语言解读

节点6:可视化建议(可选)
  → 根据数据类型建议图表类型

SQL 生成 Prompt(核心)

你是一位专业的数据分析师,精通 MySQL 和业务数据分析。

## 数据库信息:
{{schema_context}}

## 业务规则:
- 所有金额单位为人民币元
- 日期范围:"本月"指当月1日到今天,"上月"指上月全月,"今年"指今年1月1日到今天
- 大区定义:华北(京津冀晋蒙)、华东(沪苏浙皖赣)、华南(粤闽琼桂)、华西(川渝云贵藏)、华中(豫鄂湘)、东北(黑吉辽)、西北(陕甘宁青新)

## 任务:
将用户的问题转换为 MySQL 查询语句。

## 规则:
1. 只生成 SELECT 语句,禁止 INSERT/UPDATE/DELETE/DROP 等操作
2. 对于时间范围,使用精确的日期函数,如:
   - 本月:WHERE sale_date >= DATE_FORMAT(NOW(), '%Y-%m-01') AND sale_date <= NOW()
   - 上周:WHERE YEARWEEK(sale_date) = YEARWEEK(NOW()) - 1
3. 金额字段使用 ROUND(value, 2) 保留两位小数
4. 超过1000行的查询必须添加 LIMIT 限制
5. 如果问题不明确,选择最可能的解释,但在 SQL 注释中说明假设

## 用户问题:
{{user_question}}

## 输出格式:
```sql
-- 查询意图:[简述你理解的查询意图]
-- 假设:[如有假设,列出]
[SQL 语句]

请只输出 SQL,不要其他解释:


---

## Level 2:机制深解(3-5 年经验)

### 完整 NL2SQL 系统实现

```python
# nl2sql_service.py — NL2SQL 服务核心实现

import re
import json
import sqlparse
import mysql.connector
from typing import Optional

class NL2SQLService:
    
    def __init__(self, db_config: dict, dify_config: dict):
        self.db_config = db_config
        self.dify_config = dify_config
        self.schema_manager = SchemaManager(db_config)
    
    def process_query(self, user_question: str, session_id: str = None) -> dict:
        """
        处理用户的自然语言查询
        
        Returns:
            {
                'sql': 生成的 SQL,
                'result': 查询结果,
                'interpretation': 自然语言解读,
                'chart_suggestion': 图表建议,
                'error': 错误信息(如有)
            }
        """
        
        # Step 1: 智能 Schema 选择(不是全量 Schema,避免 Token 浪费)
        relevant_tables = self._identify_relevant_tables(user_question)
        schema_context = self.schema_manager.get_schema_context(relevant_tables)
        
        # Step 2: 调用 Dify 生成 SQL
        sql = self._generate_sql(user_question, schema_context, session_id)
        if not sql:
            return {'error': 'SQL 生成失败'}
        
        # Step 3: 安全检查
        is_safe, safety_error = self._validate_sql_safety(sql)
        if not is_safe:
            return {'error': f'SQL 安全检查未通过:{safety_error}'}
        
        # Step 4: 执行查询
        result, execute_error = self._execute_sql(sql)
        if execute_error:
            # 执行失败时,让 LLM 修正 SQL
            corrected_sql = self._fix_sql(sql, execute_error, schema_context)
            if corrected_sql:
                result, execute_error = self._execute_sql(corrected_sql)
                sql = corrected_sql  # 使用修正后的 SQL
        
        if execute_error:
            return {'sql': sql, 'error': f'查询执行失败:{execute_error}'}
        
        # Step 5: 结果解读
        interpretation = self._interpret_results(user_question, sql, result)
        
        # Step 6: 图表建议
        chart = self._suggest_chart(result)
        
        return {
            'sql': sql,
            'result': result,
            'interpretation': interpretation,
            'chart_suggestion': chart,
            'row_count': len(result.get('rows', [])),
        }
    
    def _identify_relevant_tables(self, question: str) -> list:
        """
        根据问题关键词智能选择相关表
        避免把所有表的 Schema 都塞入 Prompt
        """
        table_keywords = {
            'daily_sales': ['销售', '销量', '营业额', '收入', '流水', 'GMV', '订单'],
            'stores': ['门店', '店铺', '城市', '大区', '地区', '省份'],
            'inventory': ['库存', '库存量', '缺货', '积压'],
            'members': ['会员', '用户', '客户', '消费者'],
            'products': ['商品', '产品', '品类', '品牌', 'SKU'],
        }
        
        relevant_tables = set()
        question_lower = question.lower()
        
        for table, keywords in table_keywords.items():
            if any(kw in question for kw in keywords):
                relevant_tables.add(table)
        
        # 确保至少有基础表
        if not relevant_tables:
            relevant_tables.add('daily_sales')
            relevant_tables.add('stores')
        
        return list(relevant_tables)
    
    def _validate_sql_safety(self, sql: str) -> tuple[bool, Optional[str]]:
        """
        SQL 安全验证
        只允许 SELECT 语句
        """
        # 移除注释
        sql_clean = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
        sql_clean = re.sub(r'/\*.*?\*/', '', sql_clean, flags=re.DOTALL)
        sql_clean = sql_clean.strip().upper()
        
        # 必须以 SELECT 开头
        if not sql_clean.startswith('SELECT'):
            return False, f"只允许 SELECT 语句,当前语句以 {sql_clean[:20]} 开头"
        
        # 禁止的关键词
        dangerous_keywords = [
            'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 
            'TRUNCATE', 'EXEC', 'EXECUTE', 'CALL', 'GRANT', 'REVOKE',
            'INTO OUTFILE', 'LOAD DATA', 'INFORMATION_SCHEMA'
        ]
        
        for keyword in dangerous_keywords:
            if keyword in sql_clean:
                return False, f"禁止使用关键词:{keyword}"
        
        # 检查是否有 SQL 注入特征
        injection_patterns = [
            r"'\s*OR\s+'",    # ' OR '
            r"'\s*OR\s+1=1",  # ' OR 1=1
            r";\s*--",        # ; --
            r"UNION\s+ALL\s+SELECT",  # UNION ALL SELECT
        ]
        
        for pattern in injection_patterns:
            if re.search(pattern, sql_clean, re.IGNORECASE):
                return False, f"检测到可能的 SQL 注入:{pattern}"
        
        # 检查 LIMIT(超大查询)
        parsed = sqlparse.parse(sql)[0]
        has_limit = any(
            token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'LIMIT'
            for token in parsed.flatten()
        )
        
        # 如果没有 LIMIT,添加保护性限制
        # 注意:不拒绝,只是告警(在执行层面添加 LIMIT)
        
        return True, None
    
    def _execute_sql(self, sql: str, max_rows: int = 1000) -> tuple:
        """
        执行 SQL 查询(只读连接)
        """
        connection = None
        try:
            # 使用只读账号连接
            connection = mysql.connector.connect(
                host=self.db_config['host'],
                user=self.db_config['readonly_user'],
                password=self.db_config['readonly_password'],
                database=self.db_config['database'],
                connection_timeout=30
            )
            
            cursor = connection.cursor(dictionary=True)
            
            # 自动添加 LIMIT 保护
            sql_with_limit = self._add_limit_if_missing(sql, max_rows)
            
            cursor.execute(sql_with_limit)
            rows = cursor.fetchall()
            
            # 处理 Decimal/datetime 类型
            rows = self._serialize_rows(rows)
            
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            
            return {
                'columns': columns,
                'rows': rows,
                'total_rows': len(rows),
                'truncated': len(rows) == max_rows
            }, None
        
        except mysql.connector.Error as e:
            return None, str(e)
        
        finally:
            if connection and connection.is_connected():
                connection.close()
    
    def _serialize_rows(self, rows: list) -> list:
        """处理 MySQL 返回的特殊类型"""
        import decimal
        from datetime import date, datetime
        
        serialized = []
        for row in rows:
            clean_row = {}
            for key, value in row.items():
                if isinstance(value, decimal.Decimal):
                    clean_row[key] = float(value)
                elif isinstance(value, (date, datetime)):
                    clean_row[key] = value.isoformat()
                else:
                    clean_row[key] = value
            serialized.append(clean_row)
        return serialized
    
    def _fix_sql(self, failed_sql: str, error_msg: str, schema_context: str) -> Optional[str]:
        """当 SQL 执行失败时,让 LLM 修正"""
        fix_prompt = f"""
以下 SQL 执行失败,请修正:

失败的 SQL:
```sql
{failed_sql}

错误信息:{error_msg}

数据库 Schema: {schema_context}

请输出修正后的 SQL(只输出 SQL,不要解释): """ return self._call_llm(fix_prompt)

def _suggest_chart(self, result: dict) -> dict:
    """根据查询结果建议合适的图表类型"""
    if not result or not result.get('rows'):
        return {'type': 'none', 'reason': '无数据'}
    
    columns = result['columns']
    row_count = len(result['rows'])
    
    # 检测数据特征
    has_date_column = any(
        col.lower() in ['date', 'month', 'week', 'sale_date', 'created_at']
        for col in columns
    )
    has_numeric_columns = sum(
        1 for col in columns 
        if isinstance(result['rows'][0].get(col), (int, float))
    )
    has_category_column = any(
        col.lower() in ['category', 'region', 'province', 'store_type', 'product_name']
        for col in columns
    )
    
    # 图表类型建议逻辑
    if has_date_column and has_numeric_columns >= 1 and row_count > 5:
        return {
            'type': 'line',
            'reason': '时间序列数据,折线图最适合展示趋势',
            'x_axis': next(c for c in columns if c.lower() in ['date', 'month', 'week', 'sale_date']),
            'y_axis': [c for c in columns if isinstance(result['rows'][0].get(c), (int, float))]
        }
    
    elif has_category_column and has_numeric_columns >= 1 and row_count <= 20:
        return {
            'type': 'bar',
            'reason': '分类比较数据,柱状图清晰展示对比',
            'x_axis': next(c for c in columns if c.lower() in 
                           ['category', 'region', 'province', 'store_type', 'product_name']),
            'y_axis': [c for c in columns if isinstance(result['rows'][0].get(c), (int, float))][:1]
        }
    
    elif has_numeric_columns == 1 and row_count <= 10 and has_category_column:
        return {
            'type': 'pie',
            'reason': '占比数据,饼图直观展示份额',
        }
    
    else:
        return {
            'type': 'table',
            'reason': '数据较复杂,表格展示最全面'
        }

class SchemaManager: """管理数据库 Schema,支持动态获取"""

def __init__(self, db_config: dict):
    self.db_config = db_config
    self._schema_cache = {}

def get_schema_context(self, tables: list) -> str:
    """获取指定表的 Schema 上下文"""
    schema_parts = []
    
    for table in tables:
        if table not in self._schema_cache:
            self._schema_cache[table] = self._fetch_schema(table)
        schema_parts.append(self._schema_cache[table])
    
    return '\n\n'.join(schema_parts)

def _fetch_schema(self, table: str) -> str:
    """从数据库动态获取表结构"""
    connection = mysql.connector.connect(**self.db_config)
    cursor = connection.cursor()
    
    # 获取列信息
    cursor.execute(f"DESCRIBE `{table}`")
    columns = cursor.fetchall()
    
    # 获取索引信息
    cursor.execute(f"SHOW INDEX FROM `{table}`")
    indexes = cursor.fetchall()
    
    # 获取样本数据
    cursor.execute(f"SELECT * FROM `{table}` LIMIT 3")
    sample_rows = cursor.fetchall()
    
    connection.close()
    
    # 格式化输出
    schema_text = f"## 表:{table}\n"
    schema_text += "| 字段名 | 类型 | 说明 |\n|--------|------|------|\n"
    for col in columns:
        schema_text += f"| {col[0]} | {col[1]} | |\n"
    
    return schema_text

### 图表生成与报告自动化

```python
# report_generator.py — 自动化报告生成

import plotly.graph_objects as go
import plotly.express as px
import base64
from io import BytesIO

class ReportGenerator:
    
    def __init__(self, nl2sql_service: NL2SQLService):
        self.nl2sql = nl2sql_service
    
    def generate_weekly_report(self, week_start: str, week_end: str) -> dict:
        """
        自动生成周报
        
        Args:
            week_start: 周开始日期(YYYY-MM-DD)
            week_end: 周结束日期(YYYY-MM-DD)
        """
        
        # 预定义的周报查询(关键指标)
        report_queries = [
            {
                'title': '本周整体销售概况',
                'question': f'{week_start} 到 {week_end} 的总销售额、订单量、客单价',
                'chart_type': 'kpi'
            },
            {
                'title': '各大区销售额对比',
                'question': f'{week_start} 到 {week_end} 各大区销售额排名',
                'chart_type': 'bar'
            },
            {
                'title': '日销售额趋势',
                'question': f'{week_start} 到 {week_end} 每日销售额趋势',
                'chart_type': 'line'
            },
            {
                'title': 'Top 10 热销商品',
                'question': f'{week_start} 到 {week_end} 销量最高的10个商品',
                'chart_type': 'bar'
            },
            {
                'title': '品类销售占比',
                'question': f'{week_start} 到 {week_end} 各品类销售额占比',
                'chart_type': 'pie'
            },
        ]
        
        report_sections = []
        
        for query_config in report_queries:
            # 执行查询
            result = self.nl2sql.process_query(query_config['question'])
            
            if 'error' in result:
                report_sections.append({
                    'title': query_config['title'],
                    'error': result['error']
                })
                continue
            
            # 生成图表
            chart_data = self._generate_chart(
                result['result'],
                query_config['chart_type'],
                query_config['title']
            )
            
            # 生成文字分析
            analysis = self._generate_analysis(
                query_config['question'],
                result['result'],
                result.get('interpretation', '')
            )
            
            report_sections.append({
                'title': query_config['title'],
                'chart': chart_data,
                'data': result['result'],
                'analysis': analysis,
                'sql': result['sql']
            })
        
        # 生成执行摘要(由 LLM 综合所有数据)
        executive_summary = self._generate_executive_summary(
            week_start, week_end, report_sections
        )
        
        return {
            'report_type': 'weekly',
            'period': f'{week_start} 至 {week_end}',
            'generated_at': datetime.now().isoformat(),
            'executive_summary': executive_summary,
            'sections': report_sections
        }
    
    def _generate_chart(self, data: dict, chart_type: str, title: str) -> str:
        """生成 Plotly 图表,返回 base64 编码的图片"""
        
        if not data or not data.get('rows'):
            return None
        
        rows = data['rows']
        columns = data['columns']
        
        if chart_type == 'bar' and len(columns) >= 2:
            x_col = columns[0]
            y_col = columns[1]
            
            fig = px.bar(
                rows,
                x=x_col,
                y=y_col,
                title=title,
                color=y_col,
                color_continuous_scale='Blues'
            )
            fig.update_layout(
                font=dict(family='Microsoft YaHei, Arial', size=12),
                title_font_size=16,
                height=400
            )
        
        elif chart_type == 'line' and len(columns) >= 2:
            x_col = columns[0]
            y_col = columns[1]
            
            fig = px.line(
                rows,
                x=x_col,
                y=y_col,
                title=title,
                markers=True
            )
        
        elif chart_type == 'pie' and len(columns) >= 2:
            label_col = columns[0]
            value_col = columns[1]
            
            fig = px.pie(
                rows,
                names=label_col,
                values=value_col,
                title=title
            )
        
        else:
            return None
        
        # 导出为 PNG
        img_bytes = fig.to_image(format='png', width=800, height=400, scale=2)
        return base64.b64encode(img_bytes).decode('utf-8')
    
    def _generate_executive_summary(
        self, week_start: str, week_end: str, sections: list
    ) -> str:
        """用 LLM 综合生成执行摘要"""
        
        # 收集关键数据
        summary_data = {
            section['title']: {
                'analysis': section.get('analysis', ''),
                'data_preview': section.get('data', {}).get('rows', [])[:3]
            }
            for section in sections
            if 'error' not in section
        }
        
        prompt = f"""
你是一位资深的零售数据分析师。请基于以下数据,为 {week_start} 到 {week_end} 的周报写一份300字左右的执行摘要。

数据摘要:
{json.dumps(summary_data, ensure_ascii=False, indent=2)}

要求:
1. 突出最重要的 3-5 个业务洞察
2. 明确指出哪些指标表现超预期,哪些需要关注
3. 给出 2-3 条可操作的建议
4. 语言简洁专业,避免废话

执行摘要:
"""
        return self._call_llm(prompt)
    
    def export_to_markdown(self, report: dict) -> str:
        """将报告导出为 Markdown 格式"""
        md = f"# {report['period']} 周报\n\n"
        md += f"**生成时间**:{report['generated_at']}\n\n"
        md += "## 执行摘要\n\n"
        md += report['executive_summary'] + "\n\n"
        md += "---\n\n"
        
        for section in report['sections']:
            md += f"## {section['title']}\n\n"
            if 'error' in section:
                md += f"⚠️ 数据获取失败:{section['error']}\n\n"
                continue
            
            md += section.get('analysis', '') + "\n\n"
            
            # 添加数据表格
            data = section.get('data', {})
            if data.get('rows') and data.get('columns'):
                md += self._rows_to_markdown_table(data['columns'], data['rows'][:10])
            
            md += "\n"
        
        return md
    
    def _rows_to_markdown_table(self, columns: list, rows: list) -> str:
        """将查询结果转换为 Markdown 表格"""
        header = "| " + " | ".join(columns) + " |"
        separator = "| " + " | ".join(["---"] * len(columns)) + " |"
        
        data_rows = []
        for row in rows:
            values = [str(row.get(col, '')) for col in columns]
            data_rows.append("| " + " | ".join(values) + " |")
        
        return "\n".join([header, separator] + data_rows) + "\n"

Level 3:源码与原理(5 年以上)

LLM 生成 SQL 的准确率提升技巧

纯粹依赖 LLM 生成 SQL 的准确率大约在 70-80%,需要以下技术提升到 90%+:

技巧1:Few-Shot Examples(少样本示例)

在 Prompt 中加入 5-10 个问答对示例:

FEW_SHOT_EXAMPLES = [
    {
        "question": "上周各大区的销售额是多少?",
        "sql": """
SELECT 
    s.region,
    ROUND(SUM(ds.total_amount), 2) AS 销售额,
    COUNT(*) AS 订单数
FROM daily_sales ds
JOIN stores s ON ds.store_id = s.id
WHERE YEARWEEK(ds.sale_date) = YEARWEEK(NOW()) - 1
GROUP BY s.region
ORDER BY 销售额 DESC;
"""
    },
    {
        "question": "本月销量排名前5的商品",
        "sql": """
SELECT 
    product_name,
    SUM(quantity) AS 总销量,
    ROUND(SUM(total_amount), 2) AS 销售额
FROM daily_sales
WHERE sale_date >= DATE_FORMAT(NOW(), '%Y-%m-01')
GROUP BY product_id, product_name
ORDER BY 总销量 DESC
LIMIT 5;
"""
    },
]

def build_prompt_with_examples(question: str, schema: str) -> str:
    examples_text = "\n\n".join([
        f"问题:{ex['question']}\nSQL:{ex['sql']}"
        for ex in FEW_SHOT_EXAMPLES
    ])
    
    return f"""
{schema}

## 示例:
{examples_text}

## 现在回答:
问题:{question}
SQL:
"""

技巧2:Think Step by Step(链式思维)

COT_PROMPT_SUFFIX = """
在生成 SQL 之前,请先分析:
1. 这个问题涉及哪些数据维度?(时间/地域/商品/门店)
2. 需要聚合计算哪些指标?(SUM/COUNT/AVG)
3. 需要关联哪些表?
4. 时间范围如何精确表达?

分析完成后,输出 SQL:
"""

技巧3:自校验(Self-Consistency)

async def generate_sql_with_self_check(question: str, schema: str) -> str:
    """生成 SQL 并进行自校验"""
    
    # 第一次生成
    sql_1 = await generate_sql(question, schema)
    
    # 让 LLM 检查自己生成的 SQL
    check_prompt = f"""
请检查以下 SQL 是否正确回答了用户问题:

用户问题:{question}

SQL:
{sql_1}

请检查:
1. SQL 的逻辑是否符合问题的意图?
2. 时间范围的表达是否正确?
3. 聚合方式是否正确?
4. 是否有语法错误?

如果有问题,请给出修正后的 SQL;如果没有问题,请输出"CORRECT":
"""
    
    check_result = await call_llm(check_prompt)
    
    if check_result.strip().upper() != 'CORRECT':
        # 提取修正后的 SQL
        corrected_sql = extract_sql_from_text(check_result)
        return corrected_sql or sql_1
    
    return sql_1

多轮对话中的 NL2SQL

在对话中,用户可能进行追问,如:

class ConversationalNL2SQL:
    """支持多轮对话的 NL2SQL"""
    
    def __init__(self, nl2sql_service: NL2SQLService):
        self.nl2sql = nl2sql_service
        self.conversation_history = []
    
    def process(self, user_question: str) -> dict:
        # 如果有历史对话,生成上下文感知的问题
        if self.conversation_history:
            contextual_question = self._rewrite_with_context(
                user_question, 
                self.conversation_history
            )
        else:
            contextual_question = user_question
        
        # 执行查询
        result = self.nl2sql.process_query(contextual_question)
        
        # 更新历史
        self.conversation_history.append({
            'user': user_question,
            'sql': result.get('sql'),
            'summary': result.get('interpretation', '')[:200]
        })
        
        # 保留最近5轮
        if len(self.conversation_history) > 5:
            self.conversation_history = self.conversation_history[-5:]
        
        return result
    
    def _rewrite_with_context(self, question: str, history: list) -> str:
        """将追问改写为独立的完整问题"""
        context = "\n".join([
            f"Q: {h['user']}\nSQL摘要: {h['summary']}"
            for h in history[-3:]
        ])
        
        rewrite_prompt = f"""
基于以下对话历史,将用户的追问改写为独立的完整问题:

历史对话:
{context}

用户追问:{question}

改写后的完整问题(直接输出问题文本):
"""
        return self._call_llm(rewrite_prompt)

Level 4:生产陷阱与决策(专家视角)

陷阱1:LLM 生成了危险的全表扫描 SQL

症状:用户问"哪些商品从来没卖出去过",LLM 生成了无限制的全表扫描:

-- 危险!2亿行全表扫描,会让数据库瘫痪
SELECT * FROM daily_sales WHERE product_id NOT IN (
    SELECT DISTINCT product_id FROM daily_sales
);

预防措施:

  1. 语句级别限制:所有查询强制添加 LIMIT 和超时:
def add_protection(sql: str, max_rows: int = 1000, timeout_ms: int = 30000) -> str:
    # 添加 SQL 超时提示(MySQL)
    protected = f"/*+ MAX_EXECUTION_TIME({timeout_ms}) */ {sql}"
    
    # 如果没有 LIMIT,添加
    if 'LIMIT' not in sql.upper():
        protected = protected.rstrip(';') + f" LIMIT {max_rows}"
    
    return protected
  1. 使用只读副本:NL2SQL 查询必须路由到只读从库,保护主库:
# 使用 MySQLdb 的 connect_args 强制只读
connection = mysql.connector.connect(
    host=REPLICA_HOST,  # 从库地址
    user='readonly_user',
    ...
)
  1. 慢查询监控:超过 5 秒的查询自动告警并终止:
import threading

def execute_with_timeout(cursor, sql: str, timeout: int = 30) -> list:
    """带超时的 SQL 执行"""
    result = []
    error = []
    
    def target():
        try:
            cursor.execute(sql)
            result.extend(cursor.fetchall())
        except Exception as e:
            error.append(str(e))
    
    thread = threading.Thread(target=target)
    thread.start()
    thread.join(timeout=timeout)
    
    if thread.is_alive():
        cursor.execute("KILL QUERY " + str(cursor.connection.connection_id))
        raise TimeoutError(f"查询超时(>{timeout}秒),请优化查询条件")
    
    if error:
        raise Exception(error[0])
    
    return result

陷阱2:Schema 泄露风险

NL2SQL 需要向 LLM 发送完整的表结构,这可能泄露敏感的业务数据结构。

防范措施:

  1. Schema 脱敏:不发送真实的列注释中的敏感信息
  2. 使用 Azure OpenAI 或私有化部署的 LLM:数据不出企业
  3. 只发送必要的 Schema:根据问题智能选择相关表

陷阱3:结果解读的幻觉问题

用户问"北京门店的销售额",数据库返回 ¥1,234,567,LLM 解读时可能说"北京门店本月销售额为 123.4 万元,同比增长 15%"——但"同比增长 15%"是 LLM 瞎编的!

解决:强制约束 LLM 解读只能基于实际返回数据:

def interpret_results_safely(question: str, sql: str, result: dict) -> str:
    data_json = json.dumps(result['rows'][:20], ensure_ascii=False)
    
    prompt = f"""
用户问题:{question}
查询的 SQL:{sql}
查询结果:{data_json}

请用自然语言解读以上查询结果。

严格规则:
1. 只能基于查询结果中的数据进行描述
2. 禁止推断或计算查询结果中不存在的指标(如同比、环比)
3. 如果数据中没有对比数据,不要提及"增长"或"下降"
4. 如果查询结果为空,请说明"未找到符合条件的数据"

解读:
"""
    return call_llm(prompt)

实际业务效果

该连锁零售集团上线 3 个月后:

指标 上线前 上线后 改善
数据团队日均处理查询数 150 35 -77%
业务人员等待时间 平均 4 小时 < 2 分钟 -97%
周报制作时间 8 小时 45 分钟 -91%
数据分析师 vs 高价值分析时间占比 30% 70% +133%
NL2SQL 准确率(正确生成可执行SQL) - 91% -
用户满意度(自助查询) - 4.3/5 -

本章小结

数据分析助手的核心技术要点

  1. Schema 上下文是 NL2SQL 的关键:LLM 不了解你的业务数据结构,精心设计的 Schema 文档(含示例数据、业务规则)是准确率的基础。

  2. 安全性不可妥协:只读账号、SQL 白名单(只允许 SELECT)、超时机制、行数限制——四道防线缺一不可。

  3. Few-Shot + CoT 能把准确率从 75% 提升到 90%+:少样本示例和链式思维提示是实用且有效的优化手段。

  4. 报告自动化 = 预定义查询 + LLM 洞察:周报/月报不是"让 AI 自由发挥",而是围绕预定义的关键指标查询,再由 LLM 生成洞察和摘要。

  5. 结果解读必须严格约束:防止 LLM 在解读数据时添加不存在的同比/环比数据——这种幻觉在商业决策场景极度危险。

  6. 多轮对话需要问题改写:追问需要先改写为独立的完整问题,再走 NL2SQL 流程,这是保证上下文一致性的正确做法。

NL2SQL 准确率提升路线图

阶段 技术 预期准确率
基础版 LLM + Schema 70-75%
进阶版 + Few-Shot + CoT 85-90%
生产版 + 自校验 + 失败自修正 90-93%
专家版 + 向量化 Schema 检索 + 微调 95%+
本章评分
4.6  / 5  (5 评分)

💬 留言讨论