Building a simple AI knowledge base in Python

I'm sure everyone has had the experience of asking an AI about some niche corner of a field and not getting an accurate answer.

For example: "满天翔" is "Fujian's own little 彭于晏 (Eddie Peng)".

When you ask the AI directly, the answer you get looks like this:

If you have a knowledge base, though:


(The garbled text there is because I forgot to set utf-8 in my project; the code below now sets it explicitly.)

OK, let's get started.

This project builds a knowledge base from a given text file and returns the most relevant text chunk for a user's query. We'll use the requests library to handle HTTP requests, and numpy together with the cosine_similarity function from scikit-learn to compute the similarity between text embeddings.
Then, by combining the returned text chunk with the question and sending them to the model together, you get a simple knowledge-base effect.

For example:

new_prompt = f"Based on the following information:\n\n{relevant_chunk}\n\nPlease answer this question: {query}"
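
Sending that assembled prompt to a chat model completes the loop. A minimal sketch, assuming API_URL and HEADERS as configured in Step 2 below, and assuming the same provider also exposes an OpenAI-compatible /v1/chat/completions endpoint (the model name here is just a placeholder):

def answer_with_context(new_prompt):
    # Assumption: the embedding provider also serves an OpenAI-compatible chat endpoint
    response = requests.post(f"{API_URL}/v1/chat/completions",
                             headers=HEADERS,
                             json={
                                 "model": "gpt-3.5-turbo",  # placeholder model name
                                 "messages": [{"role": "user", "content": new_prompt}]
                             })
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']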

Preparation

  1. Install the dependencies

    pip install flask requests numpy scikit-learn
    
  2. Get an API key: you need an API URL and API KEY from a platform that provides a text-embedding service.

Step 1: Create the Flask app

from flask import Flask, request, jsonify
import requests
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)

Step 2: Configure the API details

API_URL = "your API URL"
API_KEY = "your API key"

HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}
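
If you'd rather not hardcode credentials, a small variation (assuming you've exported API_URL and API_KEY as environment variables) is:

import os

# Read the endpoint and key from the environment instead of hardcoding them
API_URL = os.environ["API_URL"]
API_KEY = os.environ["API_KEY"]

HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}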

Step 3: Define the helper functions

  • load_and_split_text(url): load text from a URL and split it into chunks.
  • get_embedding(text): get the embedding vector for a piece of text.
  • build_knowledge_base(chunks): build the knowledge base, storing each chunk together with its embedding.
  • get_most_relevant_chunk(query, knowledge_base): find the most relevant chunk for a query.

def load_and_split_text(url):
    logging.info(f"Loading text from {url}")
    response = requests.get(url)
    response.raise_for_status()
    response.encoding = 'utf-8'  # force utf-8 to avoid the garbled-text bug mentioned above
    text = response.text
    chunks = text.split('\n\n')
    logging.info(f"Split text into {len(chunks)} chunks")
    return chunks
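
To see what the splitting does, here is the same logic on an in-memory string (no network involved):

text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
chunks = text.split('\n\n')  # split on blank lines
print(chunks)  # ['First paragraph.', 'Second paragraph.', 'Third paragraph.']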


def get_embedding(text):
    logging.info("Getting embedding for text")
    response = requests.post(f"{API_URL}/v1/embeddings", 
                             headers=HEADERS,
                             json={
                                 "input": text,
                                 "model": "text-embedding-ada-002"
                             })
    response.raise_for_status()
    return response.json()['data'][0]['embedding']
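
For reference, an OpenAI-compatible /v1/embeddings endpoint typically returns JSON shaped roughly like this (vector abbreviated), which is why the function indexes into ['data'][0]['embedding']:

# Abbreviated response from an OpenAI-compatible embeddings endpoint:
# {
#   "object": "list",
#   "data": [
#     {"object": "embedding", "index": 0, "embedding": [0.0023, -0.0091, ...]}
#   ],
#   "model": "text-embedding-ada-002",
#   "usage": {"prompt_tokens": 8, "total_tokens": 8}
# }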


def build_knowledge_base(chunks):
    logging.info("Building knowledge base")
    knowledge_base = []
    for i, chunk in enumerate(chunks):
        logging.info(f"Processing chunk {i+1}/{len(chunks)}")
        embedding = get_embedding(chunk)
        knowledge_base.append((chunk, embedding))
    return knowledge_base
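
build_knowledge_base makes one HTTP request per chunk, which gets slow for long files. Most OpenAI-compatible embedding endpoints also accept a list of strings as input and return one embedding per string, in order; if yours does, a batched variant (a sketch, not tested against any particular provider) could look like:

def build_knowledge_base_batched(chunks):
    logging.info("Building knowledge base (batched)")
    # Assumption: the endpoint accepts a list of inputs and preserves order
    response = requests.post(f"{API_URL}/v1/embeddings",
                             headers=HEADERS,
                             json={
                                 "input": chunks,
                                 "model": "text-embedding-ada-002"
                             })
    response.raise_for_status()
    embeddings = [item['embedding'] for item in response.json()['data']]
    return list(zip(chunks, embeddings))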


def get_most_relevant_chunk(query, knowledge_base):
    logging.info("Finding most relevant chunk")
    query_embedding = get_embedding(query)
    similarities = [cosine_similarity([query_embedding], [chunk[1]])[0][0] for chunk in knowledge_base]
    most_relevant_index = np.argmax(similarities)
    return knowledge_base[most_relevant_index][0]
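
The list comprehension above calls cosine_similarity once per chunk. Since cosine_similarity accepts matrices, you can compare the query against every chunk embedding in a single call; this variant is equivalent but faster for large knowledge bases:

def get_most_relevant_chunk_vectorized(query, knowledge_base):
    query_embedding = get_embedding(query)
    # Stack all chunk embeddings into one (n_chunks, dim) matrix
    chunk_embeddings = np.array([embedding for _, embedding in knowledge_base])
    # One call produces a (1, n_chunks) similarity matrix; take row 0
    similarities = cosine_similarity([query_embedding], chunk_embeddings)[0]
    return knowledge_base[int(np.argmax(similarities))][0]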

Step 4: Define the route that handles requests

@app.route('/get_relevant_chunk', methods=['POST'])
def get_relevant_chunk():
    logging.info("Received request for relevant chunk")
    data = request.json
    if not data:
        return jsonify({"error": "No JSON data received"}), 400
    
    file_url = data.get('file_url')
    question = data.get('question')


    if not file_url or not question:
        return jsonify({"error": "Missing file_url or question"}), 400


    try:
        chunks = load_and_split_text(file_url)
        knowledge_base = build_knowledge_base(chunks)
        relevant_chunk = get_most_relevant_chunk(question, knowledge_base)
        return jsonify({"relevant_chunk": relevant_chunk})
    except requests.RequestException as e:
        logging.error(f"Request error: {str(e)}")
        return jsonify({"error": f"Request error: {str(e)}"}), 500
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")
        return jsonify({"error": f"Unexpected error: {str(e)}"}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', debug=True, port=5003)

The complete code

from flask import Flask, request, jsonify
import requests
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import logging


app = Flask(__name__)
logging.basicConfig(level=logging.INFO)


API_URL = "your API URL"
API_KEY = "your API key"


HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}


def load_and_split_text(url):
    logging.info(f"Loading text from {url}")
    response = requests.get(url)
    response.raise_for_status()
    response.encoding = 'utf-8'  # force utf-8 to avoid the garbled-text bug mentioned earlier
    text = response.text
    chunks = text.split('\n\n')
    logging.info(f"Split text into {len(chunks)} chunks")
    return chunks


def get_embedding(text):
    logging.info("Getting embedding for text")
    response = requests.post(f"{API_URL}/v1/embeddings", 
                             headers=HEADERS,
                             json={
                                 "input": text,
                                 "model": "text-embedding-ada-002"
                             })
    response.raise_for_status()
    return response.json()['data'][0]['embedding']


def build_knowledge_base(chunks):
    logging.info("Building knowledge base")
    knowledge_base = []
    for i, chunk in enumerate(chunks):
        logging.info(f"Processing chunk {i+1}/{len(chunks)}")
        embedding = get_embedding(chunk)
        knowledge_base.append((chunk, embedding))
    return knowledge_base


def get_most_relevant_chunk(query, knowledge_base):
    logging.info("Finding most relevant chunk")
    query_embedding = get_embedding(query)
    similarities = [cosine_similarity([query_embedding], [chunk[1]])[0][0] for chunk in knowledge_base]
    most_relevant_index = np.argmax(similarities)
    return knowledge_base[most_relevant_index][0]


@app.route('/get_relevant_chunk', methods=['POST'])
def get_relevant_chunk():
    logging.info("Received request for relevant chunk")
    data = request.json
    if not data:
        return jsonify({"error": "No JSON data received"}), 400
    
    file_url = data.get('file_url')
    question = data.get('question')


    if not file_url or not question:
        return jsonify({"error": "Missing file_url or question"}), 400


    try:
        chunks = load_and_split_text(file_url)
        knowledge_base = build_knowledge_base(chunks)
        relevant_chunk = get_most_relevant_chunk(question, knowledge_base)
        return jsonify({"relevant_chunk": relevant_chunk})
    except requests.RequestException as e:
        logging.error(f"Request error: {str(e)}")
        return jsonify({"error": f"Request error: {str(e)}"}), 500
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")
        return jsonify({"error": f"Unexpected error: {str(e)}"}), 500


if __name__ == '__main__':
    logging.info("Starting Flask app...")
    app.run(host='0.0.0.0', debug=True, port=5003)

Running and testing

Run your Flask app, then use Postman or a similar tool to send a POST request to http://localhost:5003/get_relevant_chunk, making sure the request body contains JSON with the file URL and the question.
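
If you'd rather test from Python than Postman, something like this works (the file URL is just a placeholder; point it at any plain-text file reachable over HTTP):

import requests

resp = requests.post("http://localhost:5003/get_relevant_chunk",
                     json={
                         "file_url": "https://example.com/knowledge.txt",  # placeholder URL
                         "question": "满天翔是谁?"  # the example question from the intro
                     })
print(resp.json())  # {"relevant_chunk": "..."}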

Of course, most embedding models today only accept around 8k tokens of input. If your file is large, you can split it into several files, run retrieval over each piece, then take the top hit from each piece and run retrieval over those again (repeating this tree-style retrieval as needed should solve it). A rough sketch of the idea follows.
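
Here is that sketch, reusing the helpers above (tree_retrieve and the two-round scheme are my own hypothetical framing of the tree-style retrieval just described):

def tree_retrieve(query, file_urls):
    # Round 1: find the best chunk within each split file
    winners = []
    for file_url in file_urls:
        chunks = load_and_split_text(file_url)
        kb = build_knowledge_base(chunks)
        winners.append(get_most_relevant_chunk(query, kb))
    # Round 2: treat the per-file winners as a new, smaller knowledge base
    # and retrieve again; with more levels this becomes a tree of retrievals
    final_kb = build_knowledge_base(winners)
    return get_most_relevant_chunk(query, final_kb)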

This post is for reference only; I've never actually studied how real knowledge bases work under the hood, so my implementation may be a bit clumsy.

This guy's a pro, bump.

Thanks for sharing, boss!

Thanks for your share.

Bumping for the pro.

Strong work, bump!

Seriously strong :tieba_025: :tieba_014:

Strong, strong, strong. Bump.

Friend, not to be too blunt, but was this written by AI? When I asked about RAG before, the parameter names it gave me were exactly these.

Thanks for the share.

Bookmarked, thanks.

Marking this, thanks.

Sorry, I didn't quite follow it.

Thanks for sharing.

No worries, friend, sorry for being abrupt this morning. However it came about, sharing something of your own is a good thing.

A real pro. I'll give it a try once I can actually understand it...

I can't believe this is a tutorial. Scary.