I'm sure everyone has had the experience of asking an AI about some niche topic within a field and not getting an accurate answer.
For example: "满天翔" is "福建小彭于晏" (Fujian's little Eddie Peng).
When you ask the AI directly, the reply you get looks like this.
But if you have a knowledge base, you get this instead.
(The error here is because I forgot to handle UTF-8 in my project.)
OK, let's get started.
This project builds a knowledge base from a given text file and returns the most relevant text chunk for a user's query. We'll use the requests library to handle HTTP requests, and numpy together with scikit-learn's cosine_similarity function to compute the similarity between text embeddings.
Then, by combining the returned text chunk with the question and sending them to the AI together, you get a simple knowledge-base effect.
For example, in PHP:
$newPrompt = "Based on the following information:\n\n{$relevantChunk}\n\nPlease answer this question: {$query}";
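If your calling side is Python rather than PHP, here is a minimal sketch of the same idea, assuming an OpenAI-compatible /v1/chat/completions endpoint (the endpoint, model name, and variable values are placeholders, not something this project pins down):

import requests

# Placeholders -- substitute your own endpoint, key, retrieved chunk, and question.
API_URL = "https://your-api-host"
HEADERS = {"Authorization": "Bearer your-api-key", "Content-Type": "application/json"}
relevant_chunk = "..."
query = "..."

# Prepend the retrieved chunk to the question, then send the combined prompt to the model.
new_prompt = f"Based on the following information:\n\n{relevant_chunk}\n\nPlease answer this question: {query}"
response = requests.post(f"{API_URL}/v1/chat/completions",
                         headers=HEADERS,
                         json={
                             "model": "gpt-3.5-turbo",
                             "messages": [{"role": "user", "content": new_prompt}]
                         })
print(response.json()["choices"][0]["message"]["content"])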
Preparation
- Install the dependencies:
pip install flask requests numpy scikit-learn
- Get an API key: you'll need an API URL and API KEY from a platform that provides a text-embedding service.
Step 1: Create the Flask app
from flask import Flask, request, jsonify
import requests
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)
Step 2: Configure the API information
API_URL = "your API URL"
API_KEY = "your API key"
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}
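A quick aside: hard-coding the key is fine for a local demo, but if you plan to share or deploy this, reading it from environment variables is safer. A minimal sketch, assuming you export API_URL and API_KEY before starting the app (same variable names as above, just sourced differently):

import os

# Read the endpoint and key from the environment instead of the source file.
API_URL = os.environ["API_URL"]
API_KEY = os.environ["API_KEY"]
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}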
Step 3: Define the helper functions
- load_and_split_text(url): loads text from a URL and splits it into chunks.
- get_embedding(text): gets the embedding representation of a text.
- build_knowledge_base(chunks): builds the knowledge base, storing each chunk together with its embedding.
- get_most_relevant_chunk(query, knowledge_base): finds the most relevant text chunk for a query.
def load_and_split_text(url):
    logging.info(f"Loading text from {url}")
    response = requests.get(url)
    response.raise_for_status()
    response.encoding = 'utf-8'  # force UTF-8 decoding (the fix for the garbled text mentioned earlier)
    text = response.text
    chunks = text.split('\n\n')  # split on blank lines: one chunk per paragraph
    logging.info(f"Split text into {len(chunks)} chunks")
    return chunks
def get_embedding(text):
    logging.info("Getting embedding for text")
    # Assumes an OpenAI-compatible embeddings endpoint that returns
    # {"data": [{"embedding": [...]}]}
    response = requests.post(f"{API_URL}/v1/embeddings",
                             headers=HEADERS,
                             json={
                                 "input": text,
                                 "model": "text-embedding-ada-002"
                             })
    response.raise_for_status()
    return response.json()['data'][0]['embedding']
def build_knowledge_base(chunks):
    logging.info("Building knowledge base")
    knowledge_base = []
    for i, chunk in enumerate(chunks):
        logging.info(f"Processing chunk {i+1}/{len(chunks)}")
        embedding = get_embedding(chunk)
        knowledge_base.append((chunk, embedding))
    return knowledge_base
def get_most_relevant_chunk(query, knowledge_base):
    logging.info("Finding most relevant chunk")
    query_embedding = get_embedding(query)
    similarities = [cosine_similarity([query_embedding], [chunk[1]])[0][0] for chunk in knowledge_base]
    most_relevant_index = np.argmax(similarities)
    return knowledge_base[most_relevant_index][0]
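The list comprehension in get_most_relevant_chunk calls cosine_similarity once per chunk, which is fine for small files but wasteful for big ones. Here is a sketch of a vectorized variant that stacks all embeddings into one matrix and makes a single call (a drop-in replacement with the same inputs and output):

def get_most_relevant_chunk_vectorized(query, knowledge_base):
    query_embedding = get_embedding(query)
    # Stack all chunk embeddings into an (n_chunks, dim) matrix.
    embeddings = np.array([embedding for _, embedding in knowledge_base])
    # One cosine_similarity call returns a (1, n_chunks) matrix; take row 0.
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    most_relevant_index = np.argmax(similarities)
    return knowledge_base[most_relevant_index][0]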
Step 4: Define the route that handles requests
@app.route('/get_relevant_chunk', methods=['POST'])
def get_relevant_chunk():
    logging.info("Received request for relevant chunk")
    data = request.get_json(silent=True)  # returns None instead of raising on a missing/bad JSON body
    if not data:
        return jsonify({"error": "No JSON data received"}), 400
    file_url = data.get('file_url')
    question = data.get('question')
    if not file_url or not question:
        return jsonify({"error": "Missing file_url or question"}), 400
    try:
        chunks = load_and_split_text(file_url)
        knowledge_base = build_knowledge_base(chunks)
        relevant_chunk = get_most_relevant_chunk(question, knowledge_base)
        return jsonify({"relevant_chunk": relevant_chunk})
    except requests.RequestException as e:
        logging.error(f"Request error: {str(e)}")
        return jsonify({"error": f"Request error: {str(e)}"}), 500
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")
        return jsonify({"error": f"Unexpected error: {str(e)}"}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', debug=True, port=5003)
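Note that this route re-downloads and re-embeds the whole file on every request, which costs one embeddings call per chunk each time. Here is a minimal sketch of a per-URL cache (a hypothetical helper, not part of the original code; being a plain dict, it is lost on restart):

# Hypothetical in-memory cache: file_url -> knowledge_base
knowledge_base_cache = {}

def get_or_build_knowledge_base(file_url):
    # Build the knowledge base once per URL, then reuse it.
    if file_url not in knowledge_base_cache:
        chunks = load_and_split_text(file_url)
        knowledge_base_cache[file_url] = build_knowledge_base(chunks)
    return knowledge_base_cache[file_url]

Inside the route you would then replace the load/build lines with knowledge_base = get_or_build_knowledge_base(file_url).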
Full code
from flask import Flask, request, jsonify
import requests
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)

API_URL = "your API URL"
API_KEY = "your API key"
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}

def load_and_split_text(url):
    logging.info(f"Loading text from {url}")
    response = requests.get(url)
    response.raise_for_status()
    response.encoding = 'utf-8'  # force UTF-8 decoding
    text = response.text
    chunks = text.split('\n\n')
    logging.info(f"Split text into {len(chunks)} chunks")
    return chunks

def get_embedding(text):
    logging.info("Getting embedding for text")
    response = requests.post(f"{API_URL}/v1/embeddings",
                             headers=HEADERS,
                             json={
                                 "input": text,
                                 "model": "text-embedding-ada-002"
                             })
    response.raise_for_status()
    return response.json()['data'][0]['embedding']

def build_knowledge_base(chunks):
    logging.info("Building knowledge base")
    knowledge_base = []
    for i, chunk in enumerate(chunks):
        logging.info(f"Processing chunk {i+1}/{len(chunks)}")
        embedding = get_embedding(chunk)
        knowledge_base.append((chunk, embedding))
    return knowledge_base

def get_most_relevant_chunk(query, knowledge_base):
    logging.info("Finding most relevant chunk")
    query_embedding = get_embedding(query)
    similarities = [cosine_similarity([query_embedding], [chunk[1]])[0][0] for chunk in knowledge_base]
    most_relevant_index = np.argmax(similarities)
    return knowledge_base[most_relevant_index][0]

@app.route('/get_relevant_chunk', methods=['POST'])
def get_relevant_chunk():
    logging.info("Received request for relevant chunk")
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No JSON data received"}), 400
    file_url = data.get('file_url')
    question = data.get('question')
    if not file_url or not question:
        return jsonify({"error": "Missing file_url or question"}), 400
    try:
        chunks = load_and_split_text(file_url)
        knowledge_base = build_knowledge_base(chunks)
        relevant_chunk = get_most_relevant_chunk(question, knowledge_base)
        return jsonify({"relevant_chunk": relevant_chunk})
    except requests.RequestException as e:
        logging.error(f"Request error: {str(e)}")
        return jsonify({"error": f"Request error: {str(e)}"}), 500
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")
        return jsonify({"error": f"Unexpected error: {str(e)}"}), 500

if __name__ == '__main__':
    logging.info("Starting Flask app...")
    app.run(host='0.0.0.0', debug=True, port=5003)
Running and testing
Run your Flask app, then use Postman or a similar tool to send a POST request to http://localhost:5003/get_relevant_chunk, making sure the request body is JSON containing the file URL and the question, as in the sketch below.
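If you'd rather test from a script than from Postman, here is a minimal sketch (the file URL and question are placeholders; point them at a real text file of your own):

import requests

# Placeholders -- substitute a real text file URL and your own question.
resp = requests.post("http://localhost:5003/get_relevant_chunk",
                     json={
                         "file_url": "https://example.com/my-knowledge.txt",
                         "question": "Who is 满天翔?"
                     })
print(resp.status_code)
print(resp.json())  # e.g. {"relevant_chunk": "..."}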