8.2 知识库问答系统(RAG 实战)
完整实现:文档上传 → 多格式解析 → 智能分块 → 向量化 → 混合检索 + Reranker → 引用溯源生成。
难度:⭐⭐⭐⭐ | 预计时长:3-4 周
文档问答助手是 RAG 技术的典型应用场景,核心挑战是:多格式文档解析、智能分块策略、精准检索、答案引用溯源、处理跨文档关联问题。本节构建一个支持多格式文档的企业知识库问答系统。
系统架构
┌──────────────────────────┐
│ 用户提问 │
└────────────┬─────────────┘
│
┌─────────────────┼─────────────────┐
│ │ │
┌───────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐
│ 问题重写 │ │ 查询扩展 │ │ 意图识别 │
│ (Query │ │ (关键词提取)│ │ │
│ Rewriting) │ │ │ │ │
└───────┬──────┘ └──────┬──────┘ └──────┬──────┘
└─────────────────┼─────────────────┘
│
┌────────▼─────────┐
│ 混合检索引擎 │
│ - 向量检索 │
│ - BM25 全文检索 │
│ - 元数据过滤 │
└────────┬─────────┘
│
┌────────▼─────────┐
│ 重排序(Rerank) │
│ - Cross-Encoder │
│ - Cohere/BGE │
└────────┬─────────┘
│
┌────────▼─────────┐
│ 上下文构建 │
│ - 引用标注 │
│ - 去重与合并 │
└────────┬─────────┘
│
┌────────▼─────────┐
│ LLM 生成答案 │
│ + 引用标注 │
└────────┬─────────┘
│
┌────────▼─────────┐
│ 后处理与溯源 │
│ - 答案验证 │
│ - 引用格式化 │
└──────────────────┘
8.2.1 多格式文档解析
核心库选择
| 文档类型 | 推荐库 | 优点 | 注意事项 |
|---|---|---|---|
| PDF | pypdf, pdfplumber, PyMuPDF | 支持文本+表格+图片提取 | 扫描件 PDF 需 OCR |
| Word | python-docx, docx2txt | 保留格式与结构 | docx2txt 更轻量 |
| Excel | openpyxl, pandas | 表格结构化处理 | 大文件用 read_only=True |
| PPT | python-pptx | 提取幻灯片文本与备注 | 图片需单独处理 |
| Markdown | mistune, markdown-it-py | 解析为 AST 树 | 保留代码块格式 |
| HTML | BeautifulSoup4, trafilatura | 网页正文提取 | trafilatura 去噪能力强 |
| 图片 OCR | PaddleOCR, EasyOCR | 中英文混合识别 | GPU 加速效果显著 |
统一文档解析器
python
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import mimetypes
@dataclass
class DocumentChunk:
    """A unit of parsed document text together with its indexing data."""

    # Text content of the chunk.
    content: str
    # Metadata about the chunk (source, page number, title, ...).
    metadata: dict
    # Globally unique identifier of the chunk.
    chunk_id: str
    # Embedding vector of the content; defaults to None.
    embedding: Optional[List[float]] = None
class UniversalDocumentParser:
    """Parse files of several formats into lists of ``DocumentChunk``.

    The parser for a file is selected by MIME type. Every format-specific
    parser imports its third-party dependency inside the method, so only the
    libraries for the formats that are actually parsed must be installed.
    """

    # Suffix -> MIME type fallback, used when ``mimetypes.guess_type`` returns
    # None or a type without a registered parser for a supported extension.
    _SUFFIX_TO_MIME = {
        '.pdf': 'application/pdf',
        '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
        '.md': 'text/markdown',
        '.markdown': 'text/markdown',
        '.html': 'text/html',
        '.htm': 'text/html',
        '.png': 'image/png',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
    }

    def __init__(self):
        # MIME type -> bound parser method.
        self.parsers = {
            'application/pdf': self._parse_pdf,
            'application/vnd.openxmlformats-officedocument.wordprocessingml.document': self._parse_docx,
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': self._parse_xlsx,
            'application/vnd.openxmlformats-officedocument.presentationml.presentation': self._parse_pptx,
            'text/markdown': self._parse_markdown,
            'text/html': self._parse_html,
            'image/png': self._parse_image_ocr,
            'image/jpeg': self._parse_image_ocr,
        }
        # OCR engine; created on the first OCR call and reused afterwards.
        self._ocr_engine = None

    def parse(self, file_path: str) -> List["DocumentChunk"]:
        """Parse a document and return its chunks.

        Args:
            file_path: Path of the file to parse; the format is derived from
                its name.

        Returns:
            The chunks produced by the format-specific parser.

        Raises:
            ValueError: If no parser is registered for the file's type.
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type not in self.parsers:
            # guess_type yields None for extensions it does not know, so fall
            # back to the file suffix before giving up.
            suffix = Path(file_path).suffix.lower()
            mime_type = self._SUFFIX_TO_MIME.get(suffix, mime_type)
        if mime_type not in self.parsers:
            raise ValueError(f"不支持的文档类型: {mime_type}")
        return self.parsers[mime_type](file_path)

    def _parse_pdf(self, file_path: str) -> List["DocumentChunk"]:
        """Parse a PDF into one chunk per non-empty page.

        Page text is followed by each extracted table (cells joined with
        " | ") and by one placeholder marker per embedded image.
        """
        import pdfplumber

        source_name = Path(file_path).name
        stem = Path(file_path).stem
        chunks = []
        with pdfplumber.open(file_path) as pdf:
            for page_num, page in enumerate(pdf.pages, start=1):
                parts = [page.extract_text() or ""]

                for table in page.extract_tables():
                    # Cells may be None (empty cell); joining them raw would
                    # raise TypeError, so render them as empty strings.
                    table_text = "\n".join(
                        " | ".join("" if cell is None else str(cell) for cell in row)
                        for row in table
                        if row
                    )
                    parts.append(f"\n\n[表格]\n{table_text}\n")

                # Only a marker is emitted per image; OCR of embedded images
                # is not implemented here.
                for img_index, _img_obj in enumerate(page.images):
                    parts.append(f"\n[图片 {img_index + 1}]\n")

                text = "".join(parts)
                if text.strip():
                    chunks.append(DocumentChunk(
                        content=text.strip(),
                        metadata={
                            "source": source_name,
                            "page": page_num,
                            "type": "pdf"
                        },
                        chunk_id=f"{stem}_page_{page_num}"
                    ))
        return chunks

    def _parse_docx(self, file_path: str) -> List["DocumentChunk"]:
        """Parse a Word document into section chunks plus one chunk per table.

        Paragraphs whose style name starts with "Heading" open a new section;
        the heading text becomes the ``section`` metadata and the first line
        of that section's content. Tables are appended after all sections.
        """
        import docx

        doc = docx.Document(file_path)
        source_name = Path(file_path).name
        stem = Path(file_path).stem
        chunks = []

        def section_chunk(section, lines):
            # The id uses the current chunk count, so it must be built before
            # the new chunk is appended.
            return DocumentChunk(
                content="\n".join(lines),
                metadata={
                    "source": source_name,
                    "section": section,
                    "type": "docx"
                },
                chunk_id=f"{stem}_{len(chunks)}"
            )

        current_section = ""
        section_content = []
        for para in doc.paragraphs:
            text = para.text.strip()
            if not text:
                continue
            style_name = (para.style.name if para.style is not None else "") or ""
            # A heading closes the section collected so far.
            if style_name.startswith('Heading'):
                if section_content:
                    chunks.append(section_chunk(current_section, section_content))
                    section_content = []
                current_section = text
            section_content.append(text)

        # Flush the trailing section.
        if section_content:
            chunks.append(section_chunk(current_section, section_content))

        for table_idx, table in enumerate(doc.tables):
            rows = [
                " | ".join(cell.text.strip() for cell in row.cells)
                for row in table.rows
            ]
            chunks.append(DocumentChunk(
                content=f"[表格 {table_idx + 1}]\n" + "\n".join(rows),
                metadata={
                    "source": source_name,
                    "type": "docx_table",
                    "table_index": table_idx + 1
                },
                chunk_id=f"{stem}_table_{table_idx}"
            ))
        return chunks

    def _parse_xlsx(self, file_path: str) -> List["DocumentChunk"]:
        """Parse an Excel workbook into one Markdown-table chunk per sheet."""
        import pandas as pd

        source_name = Path(file_path).name
        stem = Path(file_path).stem
        chunks = []
        # Open the workbook once and close it deterministically; every sheet
        # is read from the same handle instead of re-opening the file.
        with pd.ExcelFile(file_path) as excel_file:
            for sheet_name in excel_file.sheet_names:
                df = excel_file.parse(sheet_name)
                table_md = df.to_markdown(index=False)
                chunks.append(DocumentChunk(
                    content=f"# {sheet_name}\n\n{table_md}",
                    metadata={
                        "source": source_name,
                        "sheet": sheet_name,
                        "type": "xlsx",
                        "rows": len(df),
                        "columns": len(df.columns)
                    },
                    chunk_id=f"{stem}_{sheet_name}"
                ))
        return chunks

    def _parse_pptx(self, file_path: str) -> List["DocumentChunk"]:
        """Parse a PowerPoint file into one chunk per non-empty slide.

        A slide's chunk holds its title (as a "# " line), the text of the
        other shapes, and the speaker notes when present.
        """
        from pptx import Presentation

        prs = Presentation(file_path)
        source_name = Path(file_path).name
        stem = Path(file_path).stem
        chunks = []
        for slide_num, slide in enumerate(prs.slides, start=1):
            text_parts = []

            title_shape = slide.shapes.title
            title_id = None
            if title_shape is not None:
                title_id = title_shape.shape_id
                title_text = title_shape.text.strip()
                if title_text:
                    text_parts.append(f"# {title_text}")

            for shape in slide.shapes:
                # The title is also part of slide.shapes; skip it here so it
                # is not emitted a second time.
                if title_id is not None and shape.shape_id == title_id:
                    continue
                if hasattr(shape, "text") and shape.text.strip():
                    text_parts.append(shape.text.strip())

            if slide.has_notes_slide:
                notes_frame = slide.notes_slide.notes_text_frame
                notes_text = notes_frame.text.strip() if notes_frame is not None else ""
                if notes_text:
                    text_parts.append(f"\n[演讲备注]\n{notes_text}")

            if text_parts:
                chunks.append(DocumentChunk(
                    content="\n\n".join(text_parts),
                    metadata={
                        "source": source_name,
                        "slide": slide_num,
                        "type": "pptx"
                    },
                    chunk_id=f"{stem}_slide_{slide_num}"
                ))
        return chunks

    def _parse_markdown(self, file_path: str) -> List["DocumentChunk"]:
        """Parse a UTF-8 Markdown file into one chunk per level-1/2 heading.

        NOTE: the split is purely line-based, so a line starting with "# " or
        "## " inside a fenced code block is also treated as a heading.
        """
        import re

        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        source_name = Path(file_path).name
        stem = Path(file_path).stem
        # Split in front of every line that starts with one or two '#'.
        sections = re.split(r'\n(?=#{1,2}\s)', content)
        chunks = []
        for idx, section in enumerate(sections):
            if not section.strip():
                continue
            # The heading text, when the section starts with one, becomes the
            # section title; otherwise a positional placeholder is used.
            title_match = re.match(r'^(#{1,2})\s+(.+)$', section.split('\n')[0])
            title = title_match.group(2) if title_match else f"Section {idx + 1}"
            chunks.append(DocumentChunk(
                content=section.strip(),
                metadata={
                    "source": source_name,
                    "section": title,
                    "type": "markdown"
                },
                chunk_id=f"{stem}_{idx}"
            ))
        return chunks

    def _parse_html(self, file_path: str) -> List["DocumentChunk"]:
        """Extract the main text of a UTF-8 HTML file as a single chunk.

        Returns an empty list when trafilatura extracts no text.
        """
        import trafilatura

        with open(file_path, 'r', encoding='utf-8') as f:
            html = f.read()

        text = trafilatura.extract(html, include_tables=True, include_comments=False)
        if not text:
            return []
        return [DocumentChunk(
            content=text,
            metadata={
                "source": Path(file_path).name,
                "type": "html"
            },
            chunk_id=Path(file_path).stem
        )]

    def _parse_image_ocr(self, file_path: str) -> List["DocumentChunk"]:
        """Run OCR on an image and return the recognised text as one chunk.

        The chunk metadata carries the mean confidence of the recognised
        lines. Returns an empty list when nothing is recognised.
        """
        # Building the OCR engine is costly, so it is created once per parser
        # instance and reused for later images.
        if getattr(self, "_ocr_engine", None) is None:
            from paddleocr import PaddleOCR
            self._ocr_engine = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=False)

        result = self._ocr_engine.ocr(file_path, cls=True)
        if not result or not result[0]:
            return []

        lines = result[0]
        # Each entry is indexed as line[1][0] (text) and line[1][1] (confidence).
        text = "\n".join(line[1][0] for line in lines)
        confidence = sum(line[1][1] for line in lines) / len(lines)
        return [DocumentChunk(
            content=text,
            metadata={
                "source": Path(file_path).name,
                "type": "image_ocr",
                "confidence": confidence
            },
            chunk_id=Path(file_path).stem
        )]

# 8.2.2 智能分块策略
分块策略对比
| 策略 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|
| 固定长度 | 通用文本 | 简单高效 | 可能切断语义 |
| 语义分块 | 长文档 | 保持语义完整 | 计算成本高 |
| 结构化分块 | 有层级的文档(Markdown、Word) | 保留结构信息 | 依赖文档格式 |
| 滑动窗口 | 需要上下文重叠 | 检索召回率高 | 存储成本增加 |
高级分块实现
python
from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
import tiktoken
class SmartChunker:
    """Refine parsed chunks by token size: semantic or overlapping splits."""

    def __init__(
        self,
        chunk_size: int = 800,     # target chunk size in tokens
        chunk_overlap: int = 200,  # number of overlapping tokens
        model_name: str = "gpt-4o"
    ):
        """Store the size settings and pick the tokenizer for ``model_name``.

        Model names tiktoken cannot map to an encoding fall back to
        "cl100k_base" instead of failing at construction time.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def chunk_by_semantic(self, chunks: List["DocumentChunk"]) -> List["DocumentChunk"]:
        """Split long chunks at semantic boundaries found by an embedding model.

        Chunks longer than 1.5x ``chunk_size`` tokens are split with
        ``SemanticChunker``; shorter chunks are passed through unchanged.
        Sub-chunks inherit the parent's metadata plus a 1-based ``sub_chunk``.
        """
        from langchain_experimental.text_splitter import SemanticChunker
        from langchain_openai import OpenAIEmbeddings

        embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
        semantic_chunker = SemanticChunker(
            embeddings=embeddings,
            breakpoint_threshold_type="percentile"  # or "standard_deviation"
        )

        refined_chunks = []
        for chunk in chunks:
            if self._count_tokens(chunk.content) <= self.chunk_size * 1.5:
                refined_chunks.append(chunk)
                continue
            sub_docs = semantic_chunker.create_documents([chunk.content])
            for idx, sub_doc in enumerate(sub_docs):
                refined_chunks.append(DocumentChunk(
                    content=sub_doc.page_content,
                    metadata={**chunk.metadata, "sub_chunk": idx + 1},
                    chunk_id=f"{chunk.chunk_id}_semantic_{idx}"
                ))
        return refined_chunks

    def chunk_with_overlap(self, chunks: List["DocumentChunk"]) -> List["DocumentChunk"]:
        """Split chunks longer than ``chunk_size`` tokens into overlapping parts.

        Shorter chunks are passed through unchanged. Parts inherit the
        parent's metadata plus a 1-based ``sub_chunk`` index.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=self._count_tokens,
            # Paragraph and line breaks first, then full-width (Chinese)
            # sentence enders, then their ASCII counterparts.
            separators=["\n\n", "\n", "。", "!", "?", ".", "!", "?", " ", ""]
        )

        refined_chunks = []
        for chunk in chunks:
            if self._count_tokens(chunk.content) <= self.chunk_size:
                refined_chunks.append(chunk)
                continue
            for idx, text in enumerate(splitter.split_text(chunk.content)):
                refined_chunks.append(DocumentChunk(
                    content=text,
                    metadata={**chunk.metadata, "sub_chunk": idx + 1},
                    chunk_id=f"{chunk.chunk_id}_part_{idx}"
                ))
        return refined_chunks

    def _count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text``."""
        # Document text may contain special-token literals such as
        # "<|endoftext|>"; encode them as plain text instead of raising.
        return len(self.encoding.encode(text, disallowed_special=()))

    def add_metadata_summary(self, chunks: List["DocumentChunk"]) -> List["DocumentChunk"]:
        """Attach a one-sentence ``summary`` to the metadata of long chunks.

        Only chunks above 500 tokens are summarised, from their first 1000
        characters. The chunks are modified in place and the same list is
        returned. A failed summary request is reported and skipped, so one
        failure does not abort the whole batch.
        """
        long_chunks = [c for c in chunks if self._count_tokens(c.content) > 500]
        if not long_chunks:
            # Nothing to summarise: do not require the OpenAI client at all.
            return chunks

        from openai import OpenAI
        client = OpenAI()

        for chunk in long_chunks:
            try:
                response = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[
                        {"role": "system", "content": "用一句话概括下面文本的核心内容(30字以内):"},
                        {"role": "user", "content": chunk.content[:1000]}
                    ],
                    max_tokens=100
                )
                chunk.metadata["summary"] = response.choices[0].message.content.strip()
            except Exception as e:
                # Best effort: the summary is optional metadata.
                print(f"生成摘要失败: {e}")
        return chunks

# 8.2.3 混合检索 + 重排序
检索管道架构
python
from typing import List, Tuple
from dataclasses import dataclass
import numpy as np
@dataclass
class RetrievalResult:
    """A retrieved chunk with its score and the method that produced it."""

    # The retrieved chunk.
    chunk: "DocumentChunk"
    # Score assigned by the retrieval step.
    score: float
    # Producing method: "vector", "bm25", "hybrid".
    retrieval_method: str
class HybridRetriever:
    """Hybrid retriever: vector search + BM25, fused with RRF, then reranked."""

    def __init__(
        self,
        vector_store,            # vector database (ChromaDB/Pinecone/Qdrant)
        bm25_index=None,         # BM25 index
        reranker_model: str = "BAAI/bge-reranker-large",
        bm25_chunks: Optional[List["DocumentChunk"]] = None
    ):
        """Store the search backends and load the cross-encoder reranker.

        Args:
            vector_store: Either an object exposing
                ``similarity_search_with_score`` (LangChain-style store) or a
                raw Chroma collection exposing ``query``.
            bm25_index: Optional index exposing ``get_scores(tokens)``.
            reranker_model: Cross-encoder model name for reranking.
            bm25_chunks: Chunks in the order the BM25 index was built from,
                used to map BM25 score positions back to chunks. When omitted
                the index's ``corpus`` attribute is used if it has one.
                NOTE(review): rank_bm25 indexes do not appear to keep the
                original documents - confirm and pass ``bm25_chunks``.
        """
        self.vector_store = vector_store
        self.bm25_index = bm25_index
        self.bm25_chunks = bm25_chunks

        from sentence_transformers import CrossEncoder
        self.reranker = CrossEncoder(reranker_model)

    def retrieve(
        self,
        query: str,
        top_k: int = 20,
        rerank_top_k: int = 5,
        vector_weight: float = 0.7,
        use_rerank: bool = True,
        filters: dict = None
    ) -> List["RetrievalResult"]:
        """Run the hybrid retrieval pipeline.

        1. Vector search and, when an index is configured, BM25 search.
        2. Weighted Reciprocal Rank Fusion of both rankings.
        3. Optional cross-encoder reranking.

        Args:
            query: The user question.
            top_k: Number of candidates requested from each search backend.
            rerank_top_k: Number of results returned.
            vector_weight: RRF weight of the vector ranking; BM25 gets
                ``1 - vector_weight``.
            use_rerank: Whether to rescore the fused candidates.
            filters: Metadata filter forwarded to the vector store.

        Returns:
            At most ``rerank_top_k`` results, best first.
        """
        vector_results = self._vector_search(query, top_k, filters)
        bm25_results = self._bm25_search(query, top_k) if self.bm25_index is not None else []

        hybrid_results = self._rrf_fusion(
            vector_results,
            bm25_results,
            k=60,  # RRF hyper-parameter
            vector_weight=vector_weight
        )

        # Rerank whenever it is enabled, also for short candidate lists, so
        # the returned scores always come from the same scorer.
        if use_rerank and hybrid_results:
            hybrid_results = self._rerank(query, hybrid_results, rerank_top_k)
        return hybrid_results[:rerank_top_k]

    def _vector_search(
        self,
        query: str,
        top_k: int,
        filters: dict = None
    ) -> List["RetrievalResult"]:
        """Query the vector store and wrap the hits as ``RetrievalResult``.

        The stored ``score`` is whatever the backend reports (for a Chroma
        collection this is the distance value); fusion only uses the order.
        """
        store = self.vector_store
        if hasattr(store, "similarity_search_with_score"):
            hits = store.similarity_search_with_score(
                query,
                k=top_k,
                filter=filters
            )
            return [
                RetrievalResult(
                    chunk=DocumentChunk(
                        content=doc.page_content,
                        metadata=doc.metadata,
                        chunk_id=doc.metadata.get("chunk_id", "")
                    ),
                    score=float(score),
                    retrieval_method="vector"
                )
                for doc, score in hits
            ]

        # Raw Chroma collection: one query text, so take the first (only)
        # row of every returned column. Ids come back with the hits.
        response = store.query(
            query_texts=[query],
            n_results=top_k,
            where=filters or None,
            include=["documents", "metadatas", "distances"]
        )
        ids = response["ids"][0]
        documents = response["documents"][0]
        metadatas = response["metadatas"][0]
        distances = response["distances"][0]
        return [
            RetrievalResult(
                chunk=DocumentChunk(
                    content=content,
                    metadata=metadata or {},
                    chunk_id=chunk_id
                ),
                score=float(distance),
                retrieval_method="vector"
            )
            for chunk_id, content, metadata, distance
            in zip(ids, documents, metadatas, distances)
        ]

    def _bm25_search(self, query: str, top_k: int) -> List["RetrievalResult"]:
        """Score the query against the BM25 index and return positive hits.

        Raises:
            ValueError: If the chunks behind the index are not available.
        """
        if self.bm25_index is None:
            return []

        corpus = getattr(self, "bm25_chunks", None)
        if corpus is None:
            corpus = getattr(self.bm25_index, "corpus", None)
        if corpus is None:
            raise ValueError(
                "BM25 search needs the indexed chunks: pass bm25_chunks "
                "to HybridRetriever"
            )

        # Whitespace tokenisation; text without spaces (e.g. unsegmented
        # Chinese) becomes a single token, so the index must be built with
        # the same tokenisation for scores to be meaningful.
        tokenized_query = query.split()
        scores = self.bm25_index.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [
            RetrievalResult(
                chunk=corpus[idx],
                score=float(scores[idx]),
                retrieval_method="bm25"
            )
            for idx in top_indices
            if scores[idx] > 0
        ]

    @staticmethod
    def _result_key(result) -> str:
        """Return the fusion key of a result: its chunk id, else its text."""
        # A missing id is stored as "", which would merge unrelated chunks
        # into one fused entry; the content keeps them apart.
        return result.chunk.chunk_id or result.chunk.content

    def _rrf_fusion(
        self,
        vector_results: List["RetrievalResult"],
        bm25_results: List["RetrievalResult"],
        k: int = 60,
        vector_weight: float = 0.7
    ) -> List["RetrievalResult"]:
        """Fuse two rankings with weighted Reciprocal Rank Fusion.

        fused_score = sum(weight_i / (k + rank_i)) with 1-based ranks, where
        the vector ranking is weighted ``vector_weight`` and the BM25 ranking
        ``1 - vector_weight``. Returns new "hybrid" results, best first.
        """
        score_map = {}
        chunk_map = {}
        weighted_rankings = (
            (vector_results, vector_weight),
            (bm25_results, 1 - vector_weight),
        )
        for ranking, weight in weighted_rankings:
            for rank, result in enumerate(ranking):
                key = self._result_key(result)
                score_map[key] = score_map.get(key, 0) + weight * (1 / (k + rank + 1))
                chunk_map[key] = result.chunk

        fused_results = [
            RetrievalResult(
                chunk=chunk_map[key],
                score=score,
                retrieval_method="hybrid"
            )
            for key, score in score_map.items()
        ]
        return sorted(fused_results, key=lambda x: x.score, reverse=True)

    def _rerank(
        self,
        query: str,
        results: List["RetrievalResult"],
        top_k: int
    ) -> List["RetrievalResult"]:
        """Rescore results with the cross-encoder and return the best ``top_k``.

        The given result objects are updated in place (score and method).
        """
        if not results:
            return []

        pairs = [[query, r.chunk.content] for r in results]
        rerank_scores = self.reranker.predict(pairs)

        for result, score in zip(results, rerank_scores):
            result.score = float(score)
            result.retrieval_method = "reranked"
        return sorted(results, key=lambda x: x.score, reverse=True)[:top_k]

# 8.2.4 答案生成与引用溯源
带引用的答案生成
python
from typing import List, Dict
from openai import OpenAI
class CitationGenerator:
    """Generate answers whose statements cite the retrieved chunks."""

    # Instruction template; ``{context}`` receives the numbered chunk texts.
    _SYSTEM_PROMPT = """你是一个专业的文档问答助手。请基于提供的文档片段回答用户问题。
**重要规则**:
1. 只使用提供的文档内容回答,不要编造信息
2. 在答案中使用 [数字] 标注引用来源(如:根据文档 [1],公司成立于...)
3. 如果文档中没有相关信息,明确告知"文档中未找到相关信息"
4. 答案要准确、简洁、结构化
文档内容:
{context}
"""

    def __init__(self, model: str = "gpt-4o"):
        """Create the OpenAI client and remember the answer model."""
        self.client = OpenAI()
        self.model = model

    @staticmethod
    def _build_context(retrieved_chunks):
        """Return (context text, source records) for the retrieved chunks.

        Chunks are numbered from 1 in the given order; the number is both the
        citation marker in the context and the ``id`` of the source record.
        """
        context_parts = []
        sources = []
        for idx, result in enumerate(retrieved_chunks, start=1):
            chunk = result.chunk
            citation_id = f"[{idx}]"
            context_parts.append(
                f"{citation_id} {chunk.content}\n"
                f"来源: {chunk.metadata.get('source', 'Unknown')}, "
                f"页码: {chunk.metadata.get('page', 'N/A')}"
            )
            sources.append({
                "id": idx,
                "source": chunk.metadata.get("source", ""),
                "page": chunk.metadata.get("page"),
                "section": chunk.metadata.get("section"),
                "score": result.score
            })
        return "\n\n".join(context_parts), sources

    @staticmethod
    def _cited_sources(answer, sources):
        """Return the source records whose ``[id]`` marker occurs in ``answer``."""
        import re
        cited_ids = set(re.findall(r'\[(\d+)\]', answer))
        return [s for s in sources if str(s["id"]) in cited_ids]

    def generate_with_citations(
        self,
        question: str,
        retrieved_chunks: List["RetrievalResult"],
        conversation_history: List[dict] = None
    ) -> Dict:
        """Answer ``question`` from the retrieved chunks and report citations.

        Args:
            question: The user question.
            retrieved_chunks: Retrieval results; their order defines the
                citation numbers.
            conversation_history: Earlier chat messages (multi-turn), inserted
                between the system prompt and the question.

        Returns:
            A dict with ``answer``, ``sources`` (only the records cited in the
            answer), ``all_sources`` (every retrieved record), ``model`` and
            ``retrieved_count``.
        """
        context, sources = self._build_context(retrieved_chunks)

        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT.format(context=context)}
        ]
        if conversation_history:
            messages.extend(conversation_history)
        messages.append({"role": "user", "content": question})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0.1  # low randomness for more consistent answers
        )
        # The message content can be None; treat that as an empty answer so
        # citation extraction does not fail on a non-string.
        answer = response.choices[0].message.content or ""

        return {
            "answer": answer,
            "sources": self._cited_sources(answer, sources),
            "all_sources": sources,  # every retrieved source
            "model": self.model,
            "retrieved_count": len(retrieved_chunks)
        }

    def verify_answer(self, answer: str, sources: List[dict]) -> Dict:
        """Ask a model whether ``answer`` is fully supported by ``sources``.

        ``sources`` is inserted into the prompt as-is, so the check can only
        consider what those records contain. The records built by
        ``generate_with_citations`` hold source/page/section/score but no
        chunk text.

        Returns:
            The parsed JSON reply; the prompt requests the keys
            ``has_hallucination`` and ``unreliable_parts``.
        """
        import json

        verification_prompt = f"""请判断以下答案是否完全基于提供的文档内容,没有添加文档中不存在的信息。
答案:
{answer}
文档来源:
{sources}
请回答:
1. 是否存在幻觉(答案包含文档中没有的信息)?
2. 如果有,指出哪些部分不可靠
以 JSON 格式返回:{{"has_hallucination": true/false, "unreliable_parts": []}}
"""
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": verification_prompt}],
            response_format={"type": "json_object"}
        )
        return json.loads(response.choices[0].message.content)

# 8.2.5 完整系统实现
FastAPI 服务端
python
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import chromadb
from chromadb.utils import embedding_functions
import os

app = FastAPI(title="文档问答助手")

# Persistent vector database stored under ./chroma_db.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
embedding_func = embedding_functions.OpenAIEmbeddingFunction(
    # Take the key from the environment instead of hard-coding it in the
    # source; the placeholder is only used when the variable is not set.
    api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"),
    model_name="text-embedding-3-small"
)
collection = chroma_client.get_or_create_collection(
    name="documents",
    embedding_function=embedding_func,
    metadata={"hnsw:space": "cosine"}  # cosine distance for the HNSW index
)

# Shared pipeline components, created once at import time.
parser = UniversalDocumentParser()
chunker = SmartChunker(chunk_size=800, chunk_overlap=200)
citation_gen = CitationGenerator(model="gpt-4o")
class QueryRequest(BaseModel):
    """Request body of the /query endpoint."""

    # The user question.
    question: str
    # Optional conversation/session identifier.
    session_id: Optional[str] = None
    # Number of chunks to return after reranking.
    top_k: int = 5
    # Optional metadata filter forwarded to retrieval.
    filters: Optional[dict] = None
class QueryResponse(BaseModel):
    """Response body of the /query endpoint."""

    # The generated answer text.
    answer: str
    # Source records for the answer.
    sources: List[dict]
    # Session identifier echoed back to the client.
    session_id: str
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Upload a document, parse and chunk it, and index the chunks.

    The upload is written to a temporary file for parsing. Before indexing,
    the chunks are relabelled with the uploaded file's own name, so that the
    document list and delete-by-filename refer to the name the user knows
    rather than the temporary file name.

    Raises:
        HTTPException: 400 when the file type is not supported.
    """
    import os
    import tempfile

    # Client-supplied name; may be missing and may contain directory parts.
    original_name = Path(file.filename or "upload").name
    original_stem = Path(original_name).stem

    # Keep the suffix so the parser can detect the format from the temp path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=Path(original_name).suffix) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    try:
        try:
            raw_chunks = parser.parse(tmp_path)
        except ValueError as exc:
            # Unsupported type is a client error, not a server failure.
            raise HTTPException(status_code=400, detail=str(exc)) from exc

        # The parser derived "source" and the chunk-id prefix from the temp
        # path; map both back to the uploaded file name.
        tmp_name = Path(tmp_path).name
        tmp_stem = Path(tmp_path).stem
        for chunk in raw_chunks:
            if chunk.metadata.get("source") == tmp_name:
                chunk.metadata["source"] = original_name
            if chunk.chunk_id.startswith(tmp_stem):
                chunk.chunk_id = original_stem + chunk.chunk_id[len(tmp_stem):]

        refined_chunks = chunker.chunk_with_overlap(raw_chunks)
        refined_chunks = chunker.add_metadata_summary(refined_chunks)

        # One batched write instead of one call per chunk; upsert so that
        # uploading the same file again replaces its chunks.
        if refined_chunks:
            collection.upsert(
                documents=[chunk.content for chunk in refined_chunks],
                metadatas=[chunk.metadata for chunk in refined_chunks],
                ids=[chunk.chunk_id for chunk in refined_chunks]
            )

        return {
            "status": "success",
            "filename": file.filename,
            "chunks_count": len(refined_chunks)
        }
    finally:
        # Always remove the temporary copy.
        os.unlink(tmp_path)
# Retriever shared by all requests; built on the first query because
# constructing it loads the cross-encoder reranker model.
_retriever = None


def _get_retriever():
    """Return the shared HybridRetriever, creating it on first use."""
    global _retriever
    if _retriever is None:
        _retriever = HybridRetriever(
            vector_store=collection,
            reranker_model="BAAI/bge-reranker-large"
        )
    return _retriever


@app.post("/query", response_model=QueryResponse)
async def query_documents(req: QueryRequest):
    """Answer a question from the indexed documents.

    Retrieves up to 20 candidates, keeps the best ``req.top_k`` and generates
    an answer with citations from them.

    Raises:
        HTTPException: 404 when retrieval returns nothing.
    """
    retrieved = _get_retriever().retrieve(
        query=req.question,
        top_k=20,
        rerank_top_k=req.top_k,
        filters=req.filters
    )
    if not retrieved:
        raise HTTPException(status_code=404, detail="未找到相关文档")

    result = citation_gen.generate_with_citations(
        question=req.question,
        retrieved_chunks=retrieved
    )
    return QueryResponse(
        answer=result["answer"],
        sources=result["sources"],
        session_id=req.session_id or "default"
    )
@app.get("/documents")
async def list_documents():
    """List the indexed documents with the number of chunks of each."""
    # Group the stored chunk metadata by its "source" value.
    stored = collection.get(include=["metadatas"])
    chunk_counts = {}
    for meta in stored["metadatas"]:
        name = meta.get("source")
        if not name:
            continue
        chunk_counts[name] = chunk_counts.get(name, 0) + 1

    listing = []
    for name, count in chunk_counts.items():
        listing.append({"name": name, "chunks": count})
    return {"documents": listing}
@app.delete("/documents/{filename}")
async def delete_document(filename: str):
    """Delete every chunk whose "source" metadata equals ``filename``.

    Raises:
        HTTPException: 404 when no chunk belongs to that file.
    """
    matches = collection.get(
        where={"source": filename},
        include=["metadatas"]
    )
    chunk_ids = matches["ids"]
    if not chunk_ids:
        raise HTTPException(status_code=404, detail="文档不存在")

    collection.delete(ids=chunk_ids)
    return {"status": "success", "deleted_chunks": len(chunk_ids)}

# 前端示例(Streamlit)
python
import streamlit as st
import requests
from pathlib import Path
from urllib.parse import quote

st.title("📚 文档问答助手")
API_URL = "http://localhost:8000"
# Timeouts in seconds, so a stalled backend cannot hang the page forever.
# Indexing and answering call models and may take a while.
SHORT_TIMEOUT = 10
LONG_TIMEOUT = 300

# Sidebar: document upload and management.
with st.sidebar:
    st.header("文档管理")
    uploaded_file = st.file_uploader(
        "上传文档",
        type=["pdf", "docx", "xlsx", "pptx", "md", "html"]
    )
    if uploaded_file and st.button("上传并索引"):
        with st.spinner("解析文档中..."):
            files = {"file": (uploaded_file.name, uploaded_file, uploaded_file.type)}
            try:
                response = requests.post(f"{API_URL}/upload", files=files, timeout=LONG_TIMEOUT)
            except requests.RequestException:
                response = None
            if response is not None and response.status_code == 200:
                result = response.json()
                st.success(f"✅ 已索引 {result['chunks_count']} 个文档块")
            else:
                st.error("上传失败")

    # Indexed documents, each with a delete button.
    st.subheader("已索引文档")
    try:
        docs_response = requests.get(f"{API_URL}/documents", timeout=SHORT_TIMEOUT)
    except requests.RequestException:
        docs_response = None
        st.error("无法连接后端服务")
    if docs_response is not None and docs_response.status_code == 200:
        docs = docs_response.json()["documents"]
        for doc in docs:
            col1, col2 = st.columns([3, 1])
            col1.text(f"📄 {doc['name']}")
            if col2.button("删除", key=doc['name']):
                # The name is a URL path segment: encode spaces, non-ASCII
                # characters and reserved characters such as '/', '?' and '#'.
                encoded_name = quote(doc['name'], safe="")
                try:
                    requests.delete(f"{API_URL}/documents/{encoded_name}", timeout=SHORT_TIMEOUT)
                except requests.RequestException:
                    st.error("无法连接后端服务")
                else:
                    st.rerun()

# Main area: question answering.
st.header("💬 提问")
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
        if "sources" in msg:
            with st.expander("📎 引用来源"):
                for source in msg["sources"]:
                    st.caption(
                        f"[{source['id']}] {source['source']} - "
                        f"第 {source.get('page', 'N/A')} 页 "
                        f"(相关性: {source['score']:.2f})"
                    )

# Chat input.
if question := st.chat_input("请输入你的问题..."):
    st.session_state.messages.append({"role": "user", "content": question})
    with st.chat_message("user"):
        st.markdown(question)

    with st.chat_message("assistant"):
        with st.spinner("思考中..."):
            try:
                response = requests.post(
                    f"{API_URL}/query",
                    json={"question": question, "top_k": 5},
                    timeout=LONG_TIMEOUT
                )
            except requests.RequestException:
                response = None
            if response is not None and response.status_code == 200:
                result = response.json()
                st.markdown(result["answer"])
                # Cited sources of this answer.
                with st.expander("📎 引用来源"):
                    for source in result["sources"]:
                        st.caption(
                            f"[{source['id']}] **{source['source']}** - "
                            f"第 {source.get('page', 'N/A')} 页"
                        )
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": result["answer"],
                    "sources": result["sources"]
                })
            else:
                st.error("查询失败,请检查后端服务")

# 系统优化建议
| 优化方向 | 具体措施 | 预期效果 |
|---|---|---|
| 检索精度 | 混合检索 + 重排序 + 查询改写 | 召回率提升 15-25% |
| 响应速度 | 向量索引优化(HNSW)、缓存常见问题 | 延迟降低 40% |
| 多文档推理 | Graph RAG、多跳检索 | 支持跨文档关联问题 |
| 答案质量 | 答案验证、引用一致性检查 | 幻觉率降低 30% |
| 成本控制 | 小模型做检索、大模型做生成 | 成本降低 50% |
关键指标
| 指标 | 目标值 | 评估方法 |
|---|---|---|
| 检索召回率 | ≥ 90% | 标注数据集评估 Top-K 是否包含正确答案 |
| 答案准确率 | ≥ 85% | 人工评估或 LLM-as-Judge |
| 引用准确率 | ≥ 95% | 验证引用与答案内容的对应关系 |
| 平均响应延迟 | ≤ 3s | 端到端时间监控 |
| 幻觉率 | ≤ 5% | 答案验证机制检测 |