Skip to content

8.2 知识库问答系统(RAG 实战)

完整实现:文档上传 → 多格式解析 → 智能分块 → 向量化 → 混合检索 + Reranker → 引用溯源生成。

难度:⭐⭐⭐⭐ | 预计时长:3-4 周


文档问答助手是 RAG 技术的典型应用场景,核心挑战包括:多格式文档解析、智能分块策略、精准检索、答案引用溯源,以及跨文档关联问题的处理。本节构建一个支持多格式文档的企业知识库问答系统。

系统架构

                      ┌──────────────────────────┐
                      │   用户提问               │
                      └────────────┬─────────────┘
                                   │
                 ┌─────────────────┼─────────────────┐
                 │                 │                 │
         ┌───────▼──────┐  ┌──────▼──────┐  ┌──────▼──────┐
         │ 问题重写     │  │ 查询扩展    │  │ 意图识别    │
         │ (Query       │  │ (关键词提取)│  │             │
         │  Rewriting)  │  │             │  │             │
         └───────┬──────┘  └──────┬──────┘  └──────┬──────┘
                 └─────────────────┼─────────────────┘
                                   │
                          ┌────────▼─────────┐
                          │  混合检索引擎    │
                          │  - 向量检索      │
                          │  - BM25 全文检索 │
                          │  - 元数据过滤    │
                          └────────┬─────────┘
                                   │
                          ┌────────▼─────────┐
                          │  重排序(Rerank)  │
                          │  - Cross-Encoder │
                          │  - Cohere/BGE    │
                          └────────┬─────────┘
                                   │
                          ┌────────▼─────────┐
                          │  上下文构建      │
                          │  - 引用标注      │
                          │  - 去重与合并    │
                          └────────┬─────────┘
                                   │
                          ┌────────▼─────────┐
                          │   LLM 生成答案   │
                          │   + 引用标注     │
                          └────────┬─────────┘
                                   │
                          ┌────────▼─────────┐
                          │  后处理与溯源    │
                          │  - 答案验证      │
                          │  - 引用格式化    │
                          └──────────────────┘

8.2.1 多格式文档解析

核心库选择

| 文档类型 | 推荐库 | 优点 | 注意事项 |
|---|---|---|---|
| PDF | pypdf, pdfplumber, PyMuPDF | 支持文本+表格+图片提取 | 扫描件 PDF 需 OCR |
| Word | python-docx, docx2txt | 保留格式与结构 | docx2txt 更轻量 |
| Excel | openpyxl, pandas | 表格结构化处理 | 大文件用 read_only=True |
| PPT | python-pptx | 提取幻灯片文本与备注 | 图片需单独处理 |
| Markdown | mistune, markdown-it-py | 解析为 AST 树 | 保留代码块格式 |
| HTML | BeautifulSoup4, trafilatura | 网页正文提取 | trafilatura 去噪能力强 |
| 图片 OCR | PaddleOCR, EasyOCR | 中英文混合识别 | GPU 加速效果显著 |

统一文档解析器

python
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import mimetypes

@dataclass
class DocumentChunk:
    """A single retrievable unit produced by a parser or chunker."""
    content: str              # chunk text
    metadata: dict            # provenance metadata (source file, page, section, type, ...)
    chunk_id: str             # identifier; the parsers derive it from the file stem
    embedding: Optional[List[float]] = None  # left as None by the parsers


class UniversalDocumentParser:
    """Multi-format document parser that dispatches on MIME type.

    Each ``_parse_*`` method turns one file into a list of ``DocumentChunk``
    objects (roughly one per page / section / slide / sheet / table).
    Third-party libraries are imported inside each method, so only the
    libraries for the formats actually parsed need to be installed.
    """

    # Suffix -> MIME type, consulted before ``mimetypes``. The platform MIME
    # database differs between OSes and Python versions (``.md`` and the
    # Office Open XML suffixes are not always registered), so the formats
    # this parser supports are resolved deterministically here.
    _SUFFIX_MIME_TYPES = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ".md": "text/markdown",
        ".markdown": "text/markdown",
        ".html": "text/html",
        ".htm": "text/html",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
    }

    def __init__(self):
        self.parsers = {
            'application/pdf': self._parse_pdf,
            'application/vnd.openxmlformats-officedocument.wordprocessingml.document': self._parse_docx,
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': self._parse_xlsx,
            'application/vnd.openxmlformats-officedocument.presentationml.presentation': self._parse_pptx,
            'text/markdown': self._parse_markdown,
            'text/html': self._parse_html,
            'image/png': self._parse_image_ocr,
            'image/jpeg': self._parse_image_ocr,
        }
        # The OCR engine is expensive to construct; it is created on the
        # first image and reused afterwards.
        self._ocr = None

    def _guess_mime_type(self, file_path: str) -> Optional[str]:
        """Return the MIME type for *file_path*, preferring the suffix table."""
        suffix = Path(file_path).suffix.lower()
        return self._SUFFIX_MIME_TYPES.get(suffix) or mimetypes.guess_type(file_path)[0]

    def parse(self, file_path: str) -> List[DocumentChunk]:
        """Parse *file_path* and return its chunks.

        Raises:
            ValueError: if the file type is not supported (checked before the
                file is opened).
        """
        mime_type = self._guess_mime_type(file_path)
        parser_func = self.parsers.get(mime_type)
        if parser_func is None:
            raise ValueError(f"不支持的文档类型: {mime_type}")
        return parser_func(file_path)

    @staticmethod
    def _table_to_text(table) -> str:
        """Render a table (a list of rows of cells) as ``a | b`` lines.

        pdfplumber reports empty or merged cells as ``None``; those become
        empty strings instead of making ``str.join`` raise ``TypeError``.
        Rows that are empty or ``None`` are skipped.
        """
        return "\n".join(
            " | ".join("" if cell is None else str(cell) for cell in row)
            for row in table
            if row
        )

    def _parse_pdf(self, file_path: str) -> List[DocumentChunk]:
        """Parse a PDF into one chunk per non-empty page (text + tables + image markers)."""
        import pdfplumber

        path = Path(file_path)
        chunks = []
        with pdfplumber.open(file_path) as pdf:
            for page_num, page in enumerate(pdf.pages, start=1):
                text = page.extract_text() or ""

                for table in page.extract_tables():
                    table_text = self._table_to_text(table)
                    text += f"\n\n[表格]\n{table_text}\n"

                # Images are only marked with a placeholder; plug OCR in here if needed.
                for img_index in range(len(page.images)):
                    text += f"\n[图片 {img_index + 1}]\n"

                if text.strip():
                    chunks.append(DocumentChunk(
                        content=text.strip(),
                        metadata={
                            "source": path.name,
                            "page": page_num,
                            "type": "pdf"
                        },
                        chunk_id=f"{path.stem}_page_{page_num}"
                    ))

        return chunks

    def _parse_docx(self, file_path: str) -> List[DocumentChunk]:
        """Parse a Word document into heading-delimited sections plus one chunk per table.

        A paragraph whose style name starts with ``Heading`` closes the
        current section and starts a new one; the heading text becomes the
        first line of the new section. Tables are appended after all text
        sections, so their position relative to the text is not preserved.
        """
        import docx

        path = Path(file_path)
        doc = docx.Document(file_path)
        chunks = []
        current_section = ""
        section_content = []

        for para in doc.paragraphs:
            text = para.text.strip()
            if not text:
                continue

            # Headings act as section boundaries.
            if para.style is not None and para.style.name.startswith('Heading'):
                if section_content:
                    chunks.append(DocumentChunk(
                        content="\n".join(section_content),
                        metadata={
                            "source": path.name,
                            "section": current_section,
                            "type": "docx"
                        },
                        chunk_id=f"{path.stem}_{len(chunks)}"
                    ))
                    section_content = []
                current_section = text

            section_content.append(text)

        # Flush the last section.
        if section_content:
            chunks.append(DocumentChunk(
                content="\n".join(section_content),
                metadata={
                    "source": path.name,
                    "section": current_section,
                    "type": "docx"
                },
                chunk_id=f"{path.stem}_{len(chunks)}"
            ))

        for table_idx, table in enumerate(doc.tables):
            table_text = [
                " | ".join(cell.text.strip() for cell in row.cells)
                for row in table.rows
            ]
            chunks.append(DocumentChunk(
                content=f"[表格 {table_idx + 1}]\n" + "\n".join(table_text),
                metadata={
                    "source": path.name,
                    "type": "docx_table",
                    "table_index": table_idx + 1
                },
                chunk_id=f"{path.stem}_table_{table_idx}"
            ))

        return chunks

    def _parse_xlsx(self, file_path: str) -> List[DocumentChunk]:
        """Parse an Excel workbook into one Markdown-table chunk per sheet."""
        import pandas as pd

        path = Path(file_path)
        chunks = []
        # Open the workbook once (and close it) instead of re-reading the
        # file from disk for every sheet.
        with pd.ExcelFile(file_path) as excel_file:
            for sheet_name in excel_file.sheet_names:
                df = excel_file.parse(sheet_name)

                # to_markdown needs the optional `tabulate` package.
                table_md = df.to_markdown(index=False)

                chunks.append(DocumentChunk(
                    content=f"# {sheet_name}\n\n{table_md}",
                    metadata={
                        "source": path.name,
                        "sheet": sheet_name,
                        "type": "xlsx",
                        "rows": len(df),
                        "columns": len(df.columns)
                    },
                    chunk_id=f"{path.stem}_{sheet_name}"
                ))

        return chunks

    def _parse_pptx(self, file_path: str) -> List[DocumentChunk]:
        """Parse a PowerPoint deck into one chunk per slide (title, text, speaker notes)."""
        from pptx import Presentation

        path = Path(file_path)
        prs = Presentation(file_path)
        chunks = []

        for slide_num, slide in enumerate(prs.slides, start=1):
            text_parts = []

            title_shape = slide.shapes.title
            title_id = None
            if title_shape is not None:
                title_id = title_shape.shape_id
                title_text = title_shape.text.strip()
                if title_text:
                    text_parts.append(f"# {title_text}")

            for shape in slide.shapes:
                # The title placeholder is also part of slide.shapes; skip it
                # so the title is not emitted a second time as body text.
                if title_id is not None and shape.shape_id == title_id:
                    continue
                if hasattr(shape, "text") and shape.text.strip():
                    text_parts.append(shape.text.strip())

            if slide.has_notes_slide:
                notes_text = slide.notes_slide.notes_text_frame.text.strip()
                if notes_text:
                    text_parts.append(f"\n[演讲备注]\n{notes_text}")

            if text_parts:
                chunks.append(DocumentChunk(
                    content="\n\n".join(text_parts),
                    metadata={
                        "source": path.name,
                        "slide": slide_num,
                        "type": "pptx"
                    },
                    chunk_id=f"{path.stem}_slide_{slide_num}"
                ))

        return chunks

    def _parse_markdown(self, file_path: str) -> List[DocumentChunk]:
        """Parse Markdown, splitting before every level-1 or level-2 heading.

        NOTE: the split is purely line-based, so a ``# ...`` line inside a
        fenced code block is also treated as a heading.
        """
        import re

        path = Path(file_path)
        # utf-8-sig transparently drops a leading BOM, which would otherwise
        # stop the first "# Title" line from being recognised as a heading.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            content = f.read()

        sections = re.split(r'\n(?=#{1,2}\s)', content)
        chunks = []

        for idx, section in enumerate(sections):
            if not section.strip():
                continue
            title_match = re.match(r'^(#{1,2})\s+(.+)$', section.split('\n')[0])
            title = title_match.group(2) if title_match else f"Section {idx + 1}"

            chunks.append(DocumentChunk(
                content=section.strip(),
                metadata={
                    "source": path.name,
                    "section": title,
                    "type": "markdown"
                },
                chunk_id=f"{path.stem}_{idx}"
            ))

        return chunks

    def _parse_html(self, file_path: str) -> List[DocumentChunk]:
        """Extract the main text of an HTML page with trafilatura (single chunk or none)."""
        import trafilatura

        path = Path(file_path)
        with open(file_path, 'r', encoding='utf-8') as f:
            html = f.read()

        text = trafilatura.extract(html, include_tables=True, include_comments=False)

        if text:
            return [DocumentChunk(
                content=text,
                metadata={
                    "source": path.name,
                    "type": "html"
                },
                chunk_id=path.stem
            )]
        return []

    def _parse_image_ocr(self, file_path: str) -> List[DocumentChunk]:
        """OCR an image into a single chunk; metadata carries the mean line confidence."""
        path = Path(file_path)
        ocr = getattr(self, "_ocr", None)
        if ocr is None:
            from paddleocr import PaddleOCR
            ocr = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=False)
            self._ocr = ocr
        result = ocr.ocr(file_path, cls=True)

        if not result or not result[0]:
            return []

        # Each entry is (box, (text, confidence)).
        lines = result[0]
        text = "\n".join(line[1][0] for line in lines)

        return [DocumentChunk(
            content=text,
            metadata={
                "source": path.name,
                "type": "image_ocr",
                "confidence": sum(line[1][1] for line in lines) / len(lines)
            },
            chunk_id=path.stem
        )]

8.2.2 智能分块策略

分块策略对比

| 策略 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|
| 固定长度 | 通用文本 | 简单高效 | 可能切断语义 |
| 语义分块 | 长文档 | 保持语义完整 | 计算成本高 |
| 结构化分块 | 有层级的文档(Markdown、Word) | 保留结构信息 | 依赖文档格式 |
| 滑动窗口 | 需要上下文重叠 | 检索召回率高 | 存储成本增加 |

高级分块实现

python
from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
import tiktoken

class SmartChunker:
    """Token-aware refinement of parser output.

    Splits oversized chunks either at semantic boundaries
    (``chunk_by_semantic``) or with an overlapping sliding window
    (``chunk_with_overlap``), and can attach an LLM-written one-line summary
    to long chunks (``add_metadata_summary``).

    Args:
        chunk_size: target chunk size, in tokens.
        chunk_overlap: overlap between consecutive windows, in tokens.
        model_name: model whose tokenizer is used to count tokens; names
            unknown to the installed tiktoken fall back to ``cl100k_base``.
    """

    def __init__(
        self,
        chunk_size: int = 800,      # target chunk size (tokens)
        chunk_overlap: int = 200,   # overlap between windows (tokens)
        model_name: str = "gpt-4o"
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # encoding_for_model raises KeyError for model names the installed
            # tiktoken does not know; counts only need to be approximate here.
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def chunk_by_semantic(self, chunks: List[DocumentChunk]) -> List[DocumentChunk]:
        """Split chunks longer than 1.5 x ``chunk_size`` tokens at embedding-detected boundaries.

        Shorter chunks are passed through unchanged. Sub-chunks inherit the
        parent metadata plus a 1-based ``sub_chunk`` index.
        """
        from langchain_experimental.text_splitter import SemanticChunker
        from langchain_openai import OpenAIEmbeddings

        embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

        semantic_chunker = SemanticChunker(
            embeddings=embeddings,
            breakpoint_threshold_type="percentile"  # or "standard_deviation"
        )

        refined_chunks = []
        for chunk in chunks:
            if self._count_tokens(chunk.content) > self.chunk_size * 1.5:
                sub_docs = semantic_chunker.create_documents([chunk.content])
                for idx, sub_doc in enumerate(sub_docs):
                    refined_chunks.append(DocumentChunk(
                        content=sub_doc.page_content,
                        metadata={**chunk.metadata, "sub_chunk": idx + 1},
                        chunk_id=f"{chunk.chunk_id}_semantic_{idx}"
                    ))
            else:
                refined_chunks.append(chunk)

        return refined_chunks

    def chunk_with_overlap(self, chunks: List[DocumentChunk]) -> List[DocumentChunk]:
        """Split chunks longer than ``chunk_size`` tokens into overlapping windows.

        Split points are tried in order: paragraphs, lines, Chinese then
        Western sentence punctuation, spaces, and finally single characters.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=self._count_tokens,
            # Full-width "！"/"？" for Chinese text, then their ASCII
            # counterparts; previously the ASCII pair was listed twice and
            # Chinese exclamation/question marks were never split on.
            separators=["\n\n", "\n", "。", "！", "？", ".", "!", "?", " ", ""]
        )

        refined_chunks = []
        for chunk in chunks:
            if self._count_tokens(chunk.content) > self.chunk_size:
                texts = splitter.split_text(chunk.content)
                for idx, text in enumerate(texts):
                    refined_chunks.append(DocumentChunk(
                        content=text,
                        metadata={**chunk.metadata, "sub_chunk": idx + 1},
                        chunk_id=f"{chunk.chunk_id}_part_{idx}"
                    ))
            else:
                refined_chunks.append(chunk)

        return refined_chunks

    def _count_tokens(self, text: str) -> int:
        """Return the number of tokens in *text*.

        ``disallowed_special=()`` makes special-token strings such as
        ``<|endoftext|>`` count as ordinary text; by default tiktoken raises
        on them, which would abort indexing of any document containing one.
        """
        return len(self.encoding.encode(text, disallowed_special=()))

    def add_metadata_summary(self, chunks: List[DocumentChunk]) -> List[DocumentChunk]:
        """Attach a one-sentence LLM summary to every chunk longer than 500 tokens.

        Best effort: a failed request is logged and the chunk is left
        without a ``summary``. Chunks are modified in place and the same
        list is returned.
        """
        import logging
        from openai import OpenAI

        logger = logging.getLogger(__name__)
        client = None  # created lazily: OpenAI() fails without an API key

        for chunk in chunks:
            if self._count_tokens(chunk.content) <= 500:
                continue
            try:
                if client is None:
                    client = OpenAI()
                response = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[
                        {"role": "system", "content": "用一句话概括下面文本的核心内容(30字以内):"},
                        {"role": "user", "content": chunk.content[:1000]}
                    ],
                    max_tokens=100
                )
            except Exception as exc:  # deliberate: summaries are optional enrichment
                logger.warning("生成摘要失败 (%s): %s", chunk.chunk_id, exc)
                continue

            summary = (response.choices[0].message.content or "").strip()
            if summary:
                chunk.metadata["summary"] = summary

        return chunks

8.2.3 混合检索 + 重排序

检索管道架构

python
from typing import List, Tuple
from dataclasses import dataclass
import numpy as np

@dataclass
class RetrievalResult:
    """A retrieved chunk paired with its score.

    How ``score`` should be read depends on ``retrieval_method``: it holds
    the vector store's raw score, a BM25 score, a fused RRF score, or a
    cross-encoder score after reranking.
    """

    chunk: DocumentChunk    # the retrieved chunk itself
    score: float            # ranking score; meaning depends on retrieval_method
    retrieval_method: str   # "vector", "bm25", "hybrid" or "reranked"

class HybridRetriever:
    """Hybrid retriever: dense vector search + BM25, fused with weighted RRF,
    then optionally re-scored by a cross-encoder.

    Args:
        vector_store: either a LangChain-style store exposing
            ``similarity_search_with_score(query, k=..., filter=...)`` or a raw
            ``chromadb`` Collection (exposing ``query``), which is what the
            FastAPI service passes in.
        bm25_index: optional keyword index; see ``_bm25_search`` for the
            interface it must provide.
        reranker_model: cross-encoder model name for
            ``sentence_transformers.CrossEncoder``. Loading it is expensive,
            so build one retriever and reuse it.
        tokenizer: callable turning the query into BM25 terms. Defaults to
            whitespace splitting, which does not segment Chinese text; for
            Chinese corpora pass a word segmenter, and build the BM25 index
            with the same one.
    """

    def __init__(
        self,
        vector_store,           # LangChain vector store or chromadb Collection
        bm25_index=None,        # optional BM25 index
        reranker_model: str = "BAAI/bge-reranker-large",
        tokenizer=None          # query -> list of BM25 terms (default: str.split)
    ):
        self.vector_store = vector_store
        self.bm25_index = bm25_index
        self.tokenizer = tokenizer or str.split

        from sentence_transformers import CrossEncoder
        self.reranker = CrossEncoder(reranker_model)

    def retrieve(
        self,
        query: str,
        top_k: int = 20,
        rerank_top_k: int = 5,
        vector_weight: float = 0.7,
        use_rerank: bool = True,
        filters: dict = None
    ) -> List[RetrievalResult]:
        """Run the hybrid pipeline and return at most ``rerank_top_k`` results.

        1. Vector search and (if an index is configured) BM25, ``top_k`` each.
        2. Weighted Reciprocal Rank Fusion of both rankings.
        3. Cross-encoder reranking when ``use_rerank`` is set.
        """
        vector_results = self._vector_search(query, top_k, filters)
        bm25_results = self._bm25_search(query, top_k) if self.bm25_index else []

        hybrid_results = self._rrf_fusion(
            vector_results,
            bm25_results,
            k=60,  # RRF constant
            vector_weight=vector_weight
        )

        # Rerank whenever there is anything to rerank. Previously reranking
        # was skipped when there were <= rerank_top_k candidates, leaving
        # small result sets in RRF order with RRF-scale scores.
        if use_rerank and hybrid_results:
            hybrid_results = self._rerank(query, hybrid_results, rerank_top_k)

        return hybrid_results[:rerank_top_k]

    def _vector_search(
        self,
        query: str,
        top_k: int,
        filters: dict = None
    ) -> List[RetrievalResult]:
        """Dense retrieval, dispatching on the API the vector store offers."""
        if not hasattr(self.vector_store, "similarity_search_with_score"):
            return self._chroma_search(query, top_k, filters)

        hits = self.vector_store.similarity_search_with_score(
            query,
            k=top_k,
            filter=filters
        )

        results = []
        for rank, (doc, score) in enumerate(hits):
            metadata = doc.metadata or {}
            # Never fall back to a shared id such as "": RRF keys on chunk_id,
            # so identical ids would merge distinct hits into one entry.
            chunk_id = metadata.get("chunk_id") or getattr(doc, "id", None) or f"vector_{rank}"
            results.append(RetrievalResult(
                chunk=DocumentChunk(
                    content=doc.page_content,
                    metadata=metadata,
                    chunk_id=chunk_id
                ),
                score=float(score),
                retrieval_method="vector"
            ))
        return results

    def _chroma_search(
        self,
        query: str,
        top_k: int,
        filters: dict = None
    ) -> List[RetrievalResult]:
        """Dense retrieval against a raw ``chromadb`` Collection.

        The returned scores are Chroma distances (lower = more similar);
        RRF only uses the ranking, so the direction does not matter there.
        """
        response = self.vector_store.query(
            query_texts=[query],
            n_results=top_k,
            where=filters or None,  # treat an empty filter dict as "no filter"
            include=["documents", "metadatas", "distances"]
        )

        results = []
        for chunk_id, text, metadata, distance in zip(
            response["ids"][0],
            response["documents"][0],
            response["metadatas"][0],
            response["distances"][0],
        ):
            results.append(RetrievalResult(
                chunk=DocumentChunk(
                    content=text or "",
                    metadata=dict(metadata or {}),
                    chunk_id=chunk_id
                ),
                score=float(distance),
                retrieval_method="vector"
            ))
        return results

    def _bm25_search(self, query: str, top_k: int) -> List[RetrievalResult]:
        """Keyword retrieval over ``self.bm25_index``; only positive scores are returned.

        The index must expose ``get_scores(tokens)`` and a ``corpus`` sequence
        of ``DocumentChunk`` aligned with those scores.
        NOTE(review): a bare ``rank_bm25.BM25Okapi`` does not keep the
        documents, so presumably a thin wrapper is expected; confirm against
        the code that builds the index.
        """
        if not self.bm25_index:
            return []

        scores = self.bm25_index.get_scores(self.tokenizer(query))
        top_indices = np.argsort(scores)[::-1][:top_k]

        return [
            RetrievalResult(
                chunk=self.bm25_index.corpus[idx],
                score=float(scores[idx]),
                retrieval_method="bm25"
            )
            for idx in top_indices
            if scores[idx] > 0
        ]

    def _rrf_fusion(
        self,
        vector_results: List[RetrievalResult],
        bm25_results: List[RetrievalResult],
        k: int = 60,
        vector_weight: float = 0.7
    ) -> List[RetrievalResult]:
        """Weighted Reciprocal Rank Fusion.

        fused(d) = w * 1/(k + rank_vec(d)) + (1 - w) * 1/(k + rank_bm25(d)),
        with 1-based ranks and a term omitted when d is absent from that list.
        k = 60 is the customary constant. Results are keyed by chunk_id.
        """
        score_map = {}

        for rank, result in enumerate(vector_results, start=1):
            chunk_id = result.chunk.chunk_id
            score_map[chunk_id] = score_map.get(chunk_id, 0) + vector_weight / (k + rank)

        for rank, result in enumerate(bm25_results, start=1):
            chunk_id = result.chunk.chunk_id
            score_map[chunk_id] = score_map.get(chunk_id, 0) + (1 - vector_weight) / (k + rank)

        all_chunks = {r.chunk.chunk_id: r.chunk for r in vector_results + bm25_results}

        fused_results = [
            RetrievalResult(
                chunk=all_chunks[chunk_id],
                score=score,
                retrieval_method="hybrid"
            )
            for chunk_id, score in score_map.items()
        ]

        return sorted(fused_results, key=lambda x: x.score, reverse=True)

    def _rerank(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: int
    ) -> List[RetrievalResult]:
        """Re-score (query, chunk) pairs with the cross-encoder and keep the best ``top_k``.

        The given result objects are updated in place (score and method).
        """
        pairs = [[query, r.chunk.content] for r in results]
        rerank_scores = self.reranker.predict(pairs)

        for result, score in zip(results, rerank_scores):
            result.score = float(score)
            result.retrieval_method = "reranked"

        return sorted(results, key=lambda x: x.score, reverse=True)[:top_k]

8.2.4 答案生成与引用溯源

带引用的答案生成

python
from typing import List, Dict
from openai import OpenAI

class CitationGenerator:
    """Answer generator that grounds the LLM in numbered chunks and reports which ones it cites."""

    def __init__(self, model: str = "gpt-4o"):
        self.client = OpenAI()
        self.model = model

    def generate_with_citations(
        self,
        question: str,
        retrieved_chunks: List[RetrievalResult],
        conversation_history: List[dict] = None
    ) -> Dict:
        """Generate an answer whose claims are tagged with ``[n]`` citations.

        Args:
            question: the user question.
            retrieved_chunks: ranked retrieval results; the n-th becomes ``[n]``.
            conversation_history: optional prior chat messages, inserted
                between the system prompt and the question.

        Returns:
            dict with ``answer``, ``sources`` (entries cited in the answer),
            ``all_sources`` (every retrieved entry), ``model`` and
            ``retrieved_count``. Each source entry has ``id``, ``source``,
            ``page``, ``section``, ``score`` and ``content`` (the chunk text,
            which ``verify_answer`` needs to check the answer against).
        """
        import re

        context_parts = []
        sources = []

        for idx, result in enumerate(retrieved_chunks, start=1):
            chunk = result.chunk
            citation_id = f"[{idx}]"

            context_parts.append(
                f"{citation_id} {chunk.content}\n"
                f"来源: {chunk.metadata.get('source', 'Unknown')}, "
                f"页码: {chunk.metadata.get('page', 'N/A')}"
            )

            sources.append({
                "id": idx,
                "source": chunk.metadata.get("source", ""),
                "page": chunk.metadata.get("page"),
                "section": chunk.metadata.get("section"),
                "score": result.score,
                "content": chunk.content
            })

        context = "\n\n".join(context_parts)

        system_prompt = """你是一个专业的文档问答助手。请基于提供的文档片段回答用户问题。

**重要规则**:
1. 只使用提供的文档内容回答,不要编造信息
2. 在答案中使用 [数字] 标注引用来源(如:根据文档 [1],公司成立于...)
3. 如果文档中没有相关信息,明确告知"文档中未找到相关信息"
4. 答案要准确、简洁、结构化

文档内容:
{context}
"""

        messages = [
            {"role": "system", "content": system_prompt.format(context=context)}
        ]

        # Multi-turn: earlier messages go between the system prompt and the new question.
        if conversation_history:
            messages.extend(conversation_history)

        messages.append({"role": "user", "content": question})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.1  # low randomness for factual answers
        )

        # content can be None (e.g. a refusal); normalise so the regex below works.
        answer = response.choices[0].message.content or ""

        # Collect every cited number, including grouped forms such as
        # "[1, 2]", "[1,2]" or "[1，2]" in addition to "[1]".
        cited_ids = set()
        for group in re.findall(r'\[(\d+(?:\s*[,，]\s*\d+)*)\]', answer):
            cited_ids.update(re.findall(r'\d+', group))
        used_sources = [s for s in sources if str(s["id"]) in cited_ids]

        return {
            "answer": answer,
            "sources": used_sources,
            "all_sources": sources,
            "model": self.model,
            "retrieved_count": len(retrieved_chunks)
        }

    def verify_answer(self, answer: str, sources: List[dict]) -> Dict:
        """Ask a small model whether *answer* contains claims unsupported by *sources*.

        *sources* should be entries returned by ``generate_with_citations``;
        their ``content`` text is what the answer is checked against.
        Entries without ``content`` are included as their ``str()`` form, so
        the check can only be as good as the text supplied.

        Returns:
            The parsed JSON verdict, expected to look like
            ``{"has_hallucination": bool, "unreliable_parts": [...]}``.
        """
        import json

        source_text = "\n\n".join(
            f"[{s.get('id', i)}] {s['content']}"
            if isinstance(s, dict) and s.get("content") else str(s)
            for i, s in enumerate(sources, start=1)
        )

        verification_prompt = f"""请判断以下答案是否完全基于提供的文档内容,没有添加文档中不存在的信息。

答案:
{answer}

文档来源:
{source_text}

请回答:
1. 是否存在幻觉(答案包含文档中没有的信息)?
2. 如果有,指出哪些部分不可靠

以 JSON 格式返回:{{"has_hallucination": true/false, "unreliable_parts": []}}
"""

        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": verification_prompt}],
            response_format={"type": "json_object"}
        )

        return json.loads(response.choices[0].message.content)

8.2.5 完整系统实现

FastAPI 服务端

python
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import chromadb
from chromadb.utils import embedding_functions

import os

app = FastAPI(title="文档问答助手")

# Persistent on-disk vector store.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
# Read the key from the environment instead of hard-coding a secret (the
# placeholder "your-api-key" fails on the first embedding call and invites
# committing real keys to source control).
embedding_func = embedding_functions.OpenAIEmbeddingFunction(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model_name="text-embedding-3-small"
)

collection = chroma_client.get_or_create_collection(
    name="documents",
    embedding_function=embedding_func,
    metadata={"hnsw:space": "cosine"}  # cosine distance for the HNSW index
)

# Shared pipeline components.
parser = UniversalDocumentParser()
chunker = SmartChunker(chunk_size=800, chunk_overlap=200)
citation_gen = CitationGenerator(model="gpt-4o")

class QueryRequest(BaseModel):
    """Request body for POST /query."""

    question: str                        # the user's question
    session_id: Optional[str] = None     # echoed back in the response
    top_k: int = 5                       # number of chunks kept after reranking
    filters: Optional[dict] = None       # metadata filter forwarded to the retriever


class QueryResponse(BaseModel):
    """Response body for POST /query."""

    answer: str                          # generated answer with [n] citations
    sources: List[dict]                  # citation entries referenced by the answer
    session_id: str                      # request session_id, or "default"

@app.post("/upload")
def upload_document(file: UploadFile = File(...)):
    """Upload, parse, chunk and index one document.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool: parsing,
    OCR, tokenisation and the OpenAI summary/embedding calls all block and
    would otherwise stall the event loop for every other request.

    Re-uploading a file with the same name replaces its previous chunks.
    Responds 400 for unsupported file types.
    """
    import os
    import shutil
    import tempfile

    filename = file.filename or "upload"

    # Parsers dispatch on the suffix, so keep it on the temporary copy.
    with tempfile.NamedTemporaryFile(delete=False, suffix=Path(filename).suffix) as tmp:
        shutil.copyfileobj(file.file, tmp)
        tmp_path = tmp.name

    try:
        try:
            raw_chunks = parser.parse(tmp_path)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

        refined_chunks = chunker.chunk_with_overlap(raw_chunks)
        refined_chunks = chunker.add_metadata_summary(refined_chunks)

        # The parsers derive "source" and chunk ids from the *temporary* file
        # name (e.g. tmpab12cd.pdf). Rebase them on the uploaded name so that
        # GET /documents lists, and DELETE /documents/{filename} finds, the
        # real file, and so ids are stable across re-uploads.
        tmp_stem = Path(tmp_path).stem
        for chunk in refined_chunks:
            if chunk.chunk_id.startswith(tmp_stem):
                chunk.chunk_id = filename + chunk.chunk_id[len(tmp_stem):]
            else:
                chunk.chunk_id = f"{filename}_{chunk.chunk_id}"
            chunk.metadata["source"] = filename
            # The retriever reads chunk ids back from metadata for RRF fusion.
            chunk.metadata["chunk_id"] = chunk.chunk_id

        # Replace any previous version of this document, then index in one batch.
        collection.delete(where={"source": filename})
        if refined_chunks:
            collection.upsert(
                documents=[chunk.content for chunk in refined_chunks],
                metadatas=[chunk.metadata for chunk in refined_chunks],
                ids=[chunk.chunk_id for chunk in refined_chunks]
            )

        return {
            "status": "success",
            "filename": file.filename,
            "chunks_count": len(refined_chunks)
        }

    finally:
        os.unlink(tmp_path)

import threading

# The retriever loads a large cross-encoder; build it once per process and
# reuse it instead of reloading the model on every request.
_retriever = None
_retriever_lock = threading.Lock()


def _get_retriever():
    """Return the process-wide HybridRetriever, creating it on first use."""
    global _retriever
    if _retriever is None:
        with _retriever_lock:
            if _retriever is None:
                _retriever = HybridRetriever(
                    vector_store=collection,
                    reranker_model="BAAI/bge-reranker-large"
                )
    return _retriever


@app.post("/query", response_model=QueryResponse)
def query_documents(req: QueryRequest):
    """Answer a question from the indexed documents, with citations.

    A plain ``def`` so FastAPI runs the blocking retrieval and LLM calls in
    its threadpool rather than on the event loop. Responds 404 when nothing
    relevant is retrieved.
    """
    retrieved = _get_retriever().retrieve(
        query=req.question,
        top_k=20,
        rerank_top_k=req.top_k,
        # Treat an empty filter dict as "no filter" (Chroma may reject where={}).
        filters=req.filters or None
    )

    if not retrieved:
        raise HTTPException(status_code=404, detail="未找到相关文档")

    result = citation_gen.generate_with_citations(
        question=req.question,
        retrieved_chunks=retrieved
    )

    return QueryResponse(
        answer=result["answer"],
        sources=result["sources"],
        session_id=req.session_id or "default"
    )

@app.get("/documents")
def list_documents():
    """List indexed documents with their chunk counts.

    Counts chunks per ``source`` metadata value across the whole collection.
    Chunks stored without metadata (or without ``source``) are skipped
    instead of crashing the endpoint.
    """
    from collections import Counter

    results = collection.get(include=["metadatas"])

    counts = Counter(
        metadata["source"]
        for metadata in results.get("metadatas") or []
        if metadata and metadata.get("source")
    )

    return {
        "documents": [
            {"name": name, "chunks": count}
            for name, count in counts.items()
        ]
    }

@app.delete("/documents/{filename}")
async def delete_document(filename: str):
    """Delete every chunk whose ``source`` metadata equals *filename*.

    Responds 404 when no chunk matches.
    """
    matches = collection.get(
        where={"source": filename},
        include=["metadatas"]
    )
    chunk_ids = matches["ids"]

    if not chunk_ids:
        raise HTTPException(status_code=404, detail="文档不存在")

    collection.delete(ids=chunk_ids)
    return {"status": "success", "deleted_chunks": len(chunk_ids)}

前端示例(Streamlit)

python
import streamlit as st
import requests
from pathlib import Path

from urllib.parse import quote

st.title("📚 文档问答助手")

API_URL = "http://localhost:8000"
# Request timeouts in seconds. Uploads get a long budget because the backend
# parses, chunks, summarises and embeds the file before it responds.
UPLOAD_TIMEOUT = 600
QUERY_TIMEOUT = 120
LIST_TIMEOUT = 10


def _render_sources(sources):
    """Render citation entries returned by POST /query inside an expander."""
    with st.expander("📎 引用来源"):
        for source in sources:
            # The backend always sends a "page" key, but its value is None for
            # non-PDF sources, so .get("page", "N/A") would print "None".
            page = source.get("page") or "N/A"
            caption = f"[{source['id']}] **{source['source']}** - 第 {page} 页"
            score = source.get("score")
            if score is not None:
                caption += f" (相关性: {score:.2f})"
            st.caption(caption)


# Sidebar: document management
with st.sidebar:
    st.header("文档管理")

    uploaded_file = st.file_uploader(
        "上传文档",
        # Images are accepted as well: the backend OCRs PNG/JPEG files.
        type=["pdf", "docx", "xlsx", "pptx", "md", "html", "png", "jpg", "jpeg"]
    )

    if uploaded_file and st.button("上传并索引"):
        with st.spinner("解析文档中..."):
            files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
            try:
                response = requests.post(f"{API_URL}/upload", files=files, timeout=UPLOAD_TIMEOUT)
            except requests.RequestException as exc:
                st.error(f"上传失败: {exc}")
            else:
                if response.status_code == 200:
                    result = response.json()
                    st.success(f"✅ 已索引 {result['chunks_count']} 个文档块")
                else:
                    st.error(f"上传失败: {response.text}")

    st.subheader("已索引文档")
    try:
        docs_response = requests.get(f"{API_URL}/documents", timeout=LIST_TIMEOUT)
    except requests.RequestException:
        st.warning("无法连接后端服务")
    else:
        if docs_response.status_code == 200:
            for doc in docs_response.json()["documents"]:
                col1, col2 = st.columns([3, 1])
                col1.text(f"📄 {doc['name']}")
                if col2.button("删除", key=f"delete_{doc['name']}"):
                    # URL-encode the name: "#", "?" or "%" in a filename would
                    # otherwise truncate or corrupt the request path.
                    requests.delete(
                        f"{API_URL}/documents/{quote(doc['name'], safe='')}",
                        timeout=LIST_TIMEOUT
                    )
                    st.rerun()

# Main area: question answering
st.header("💬 提问")

if "messages" not in st.session_state:
    st.session_state.messages = []

for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
        if "sources" in msg:
            _render_sources(msg["sources"])

if question := st.chat_input("请输入你的问题..."):
    st.session_state.messages.append({"role": "user", "content": question})

    with st.chat_message("user"):
        st.markdown(question)

    with st.chat_message("assistant"):
        with st.spinner("思考中..."):
            try:
                response = requests.post(
                    f"{API_URL}/query",
                    json={"question": question, "top_k": 5},
                    timeout=QUERY_TIMEOUT
                )
            except requests.RequestException:
                response = None

        if response is not None and response.status_code == 200:
            result = response.json()
            st.markdown(result["answer"])
            _render_sources(result["sources"])

            st.session_state.messages.append({
                "role": "assistant",
                "content": result["answer"],
                "sources": result["sources"]
            })
        elif response is not None and response.status_code == 404:
            # The backend answers 404 when retrieval finds nothing; that is
            # not a service failure.
            st.warning("未找到相关文档")
        else:
            st.error("查询失败,请检查后端服务")

系统优化建议

| 优化方向 | 具体措施 | 预期效果 |
|---|---|---|
| 检索精度 | 混合检索 + 重排序 + 查询改写 | 召回率提升 15-25% |
| 响应速度 | 向量索引优化(HNSW)、缓存常见问题 | 延迟降低 40% |
| 多文档推理 | Graph RAG、多跳检索 | 支持跨文档关联问题 |
| 答案质量 | 答案验证、引用一致性检查 | 幻觉率降低 30% |
| 成本控制 | 小模型做检索、大模型做生成 | 成本降低 50% |

关键指标

| 指标 | 目标值 | 评估方法 |
|---|---|---|
| 检索召回率 | ≥ 90% | 标注数据集评估 Top-K 是否包含正确答案 |
| 答案准确率 | ≥ 85% | 人工评估或 LLM-as-Judge |
| 引用准确率 | ≥ 95% | 验证引用与答案内容的对应关系 |
| 平均响应延迟 | ≤ 3s | 端到端时间监控 |
| 幻觉率 | ≤ 5% | 答案验证机制检测 |

坚持是一种品格