我有这个想法,但我不是专业搞 AI 的。
我看到 github-copilot-proxies 这个 Github Copilot 代理仓库
我想着使用 RAG技术将我们本地的git仓库向量化,
然后修改github-copilot-proxies发请求给第三方大模型的参数
注入本地代码仓库相关的信息。从而实现更高的接受率
下面是我写的demo,这里只写了向量化的代码
有几个疑问:
1. 提示词怎么构建才能发挥出作用?
2. 如何生成向量数据库的 query?
各位佬友一起看看,觉得这个想法有没有作用~
求指导
# Path: splitter.py
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters.base import Language
# Maps a (lower-case, dot-less) file extension to the langchain Language enum
# used to pick a language-aware code splitter. Extensions that langchain has
# no splitter for are kept below as commented-out entries for reference.
EXTENSION_TO_LANGUAGE: dict[str, Language] = {
"js": Language.JS,
"jsx": Language.JS,
"mjs": Language.JS,
"cjs": Language.JS,
"py": Language.PYTHON,
"java": Language.JAVA,
"cpp": Language.CPP,
"c": Language.C,
"html": Language.HTML,
# Vue single-file components are mostly HTML-like markup.
"vue": Language.HTML,
# "css": 'css', # Unsupported
# NOTE(review): JSON is mapped to the JS splitter because the brace/bracket
# separators overlap — confirm this chunks JSON acceptably.
"json": Language.JS,
# "xml": "xml", # Unsupported
# "yaml": "yaml", # Unsupported
"md": Language.MARKDOWN,
# "sh": 'shell', # Unsupported
"ts": Language.TS,
"tsx": Language.TS,
"rs": Language.RUST,
"go": Language.GO,
"rb": Language.RUBY,
"php": Language.PHP,
# "sql": '',
"swift": Language.SWIFT,
"kt": Language.KOTLIN,
"scala": Language.SCALA,
"lua": Language.LUA,
"pl": Language.PERL,
# "r": 'r', # Unsupported
# "dart": 'dart', # Unsupported
# "elm": "elm", # Unsupported
# "clj": "clojure", # Unsupported
"hs": Language.HASKELL,
# "erl": "erlang",# Unsupported
# "ex": "elixir",# Unsupported
# "fs": "fsharp",# Unsupported
# "vb": "visualbasic",# Unsupported
# "asm": "asm",# Unsupported
# "s": "asm",# Unsupported
# "v": "verilog",# Unsupported
# "vhdl": "vhdl",# Unsupported
# "dockerfile": "dockerfile",# Unsupported
# "makefile": "make",# Unsupported
# "cmake": "cmake",# Unsupported
# "ninja": "ninja",# Unsupported
# "nix": "nix",# Unsupported
# "zig": "zig",# Unsupported
# "magik": "magik",# Unsupported
}
# Cache of already-built splitters, keyed by (language, splitter kwargs).
splitter_cache: dict = {}


def get_splitter(language: Language, **kwargs) -> RecursiveCharacterTextSplitter:
    """Return a cached ``RecursiveCharacterTextSplitter`` for *language*.

    Args:
        language: The langchain ``Language`` to build the splitter for.
        **kwargs: Extra splitter options (e.g. ``chunk_size``, ``chunk_overlap``)
            forwarded to ``RecursiveCharacterTextSplitter.from_language``.

    Returns:
        A splitter configured for the given language and options.

    Note:
        The cache key includes the kwargs. The original version keyed only on
        the language, so requesting the same language with *different* options
        silently returned a splitter built with the first call's options.
    """
    cache_key = (str(language), tuple(sorted(kwargs.items())))
    splitter = splitter_cache.get(cache_key)
    if splitter is None:
        splitter = RecursiveCharacterTextSplitter.from_language(
            language=language,
            **kwargs,
        )
        splitter_cache[cache_key] = splitter
    return splitter
def get_splitter_by_ext(ext: str) -> RecursiveCharacterTextSplitter:
    """Pick a text splitter based on a file extension.

    Args:
        ext: File extension, with or without the leading dot (e.g. ``".py"``
            or ``"py"``). ``None``/empty is tolerated and falls back to the
            generic splitter (callers may pass ``metadata.get("file_type")``,
            which can be missing).

    Returns:
        A language-aware splitter for known extensions, otherwise a generic
        splitter with code-friendly separators.
    """
    # Normalise once: the original lower-cased twice and compared against a
    # magic "unknown" string; it also crashed with AttributeError on None.
    normalized = (ext or "").lower().lstrip(".")
    language = EXTENSION_TO_LANGUAGE.get(normalized)
    if language is None:
        # Unknown/unsupported extension: generic splitter that still prefers
        # to break on code-ish boundaries.
        return RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", "}", ";", ")"],
        )
    return get_splitter(
        language,
        chunk_size=1000,
        chunk_overlap=200,
    )


__all__ = ["get_splitter_by_ext"]
# Path: demo.py
from typing import List, Union
import requests
from langchain_community.document_loaders import GitLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_chroma import Chroma
from splitter import get_splitter_by_ext
# API configuration for the remote embedding service.
# SECURITY NOTE(review): the key is hard-coded in source — load it from an
# environment variable or a secrets manager before committing/sharing.
API_KEY = "sk-xxx"
EMBED_API_URL = "https://api.siliconflow.cn/v1/embeddings"  # OpenAI-compatible embeddings endpoint
MODEL_NAME = "Pro/BAAI/bge-m3"  # embedding model identifier
DEFAULT_PERSIST_DIR = "./chroma_api_db"  # default Chroma persistence directory
class APICodeEmbedder(Embeddings):
    """Embeddings backend that calls an OpenAI-compatible ``/embeddings``
    HTTP endpoint, one text per request."""

    def __init__(self):
        # Reuse a single HTTP session for connection pooling across requests.
        self._session = requests.Session()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents via the remote API.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            RuntimeError: If the API returns a non-200 response.
        """
        embeddings: list[list[float]] = []
        for text in texts:
            resp = self._session.post(
                EMBED_API_URL,
                headers={"Authorization": f"Bearer {API_KEY}"},
                json={"input": text, "model": MODEL_NAME, "encoding_type": "float"},
                timeout=30,
            )
            # Fail fast with the server's message; RuntimeError is still
            # caught by callers using ``except Exception``.
            if resp.status_code != 200:
                raise RuntimeError(f"Embedding失败: {resp.text}")
            # The original printed the full vector here — removed, it floods
            # stdout with hundreds of floats per document.
            embeddings.append(resp.json()["data"][0]["embedding"])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector for the text.
        """
        # (Original docstring wrongly claimed this used GPT4All.)
        return self.embed_documents([text])[0]
class GitChromaFetcher:
    """Loads files from a local git repository, splits them by file type, and
    vectorises the chunks into a persistent Chroma collection."""

    def __init__(
        self,
        repo_path: str,
        persist_dir: str = None,
        embedder: Embeddings = None,
        loader_args: dict = None,
        embedder_args: dict = None,
    ):
        """
        Args:
            repo_path: Path to the local git repository.
            persist_dir: Chroma persistence directory; defaults to
                ``DEFAULT_PERSIST_DIR``.
            embedder: Embeddings implementation; defaults to an
                OpenAI-compatible client pointed at ``EMBED_API_URL``.
            loader_args: Extra kwargs forwarded to ``GitLoader``.
            embedder_args: Extra kwargs forwarded to the default embedder.
        """
        loader_args = loader_args or {}
        embedder_args = embedder_args or {}
        self._loader = GitLoader(repo_path, **loader_args)
        # Accumulates split chunks across _split_doc calls.
        self._result_docs: List[Document] = []
        self._persist_dir = persist_dir or DEFAULT_PERSIST_DIR
        self._embedder = embedder or OpenAIEmbeddings(
            model=MODEL_NAME,
            openai_api_key=API_KEY,
            openai_api_base=EMBED_API_URL,
            **embedder_args,
        )

    def _split_doc(self, document: Union[Document, List[Document]]):
        """Split one document (or a list of them), choosing the splitter from
        each document's ``file_type`` metadata.

        Bug fix: the original read ``document.metadata`` before unwrapping the
        Union, so passing a list raised AttributeError. The splitter is now
        chosen per document.
        """
        docs = document if isinstance(document, list) else [document]
        for doc in docs:
            # NOTE(review): assumes GitLoader sets a "file_type" metadata key
            # (extension with leading dot) — confirm; missing keys yield None.
            splitter = get_splitter_by_ext(doc.metadata.get("file_type"))
            self._result_docs += splitter.split_documents([doc])

    def vectorization(self) -> Chroma:
        """Load, split and embed every document in the repository.

        Returns:
            The populated Chroma vector store. (The original annotation
            claimed ``List[Document]``, but a Chroma instance is returned.)
        """
        # lazy_load streams documents instead of materialising the whole repo.
        for document in self._loader.lazy_load():
            self._split_doc(document)
        return Chroma.from_documents(
            self._result_docs,
            embedding=self._embedder,
            persist_directory=self._persist_dir,
        )

    @property
    def chroma(self):
        """Open the persisted Chroma collection for querying."""
        return Chroma(
            persist_directory=self._persist_dir,
            embedding_function=self._embedder,
        )
def main():
    """Demo entry point: vectorise a local git repo, then run one
    similarity search and print the matching chunks."""
    repo_path = r"D:\Workspace\repo"
    fetcher = GitChromaFetcher(
        repo_path,
        embedder=APICodeEmbedder(),
        loader_args={"branch": "master-v2"},
    )
    fetcher.vectorization()
    # The retrieved chunks are what would later be injected into the prompt.
    hits = fetcher.chroma.similarity_search("用户")
    for doc in hits:
        print("=" * 30)
        print(f"文件:{doc.metadata['file_path']}")
        print(f"{doc.page_content}")
        print("=" * 30)


if __name__ == "__main__":
    main()