想法:使用 RAG 对 代理的 copilot 调优

我有这个想法,但我不是专业搞 AI 的嘛

我看到 github-copilot-proxies 这个 Github Copilot 代理仓库
我想着使用 RAG技术将我们本地的git仓库向量化,
然后修改github-copilot-proxies发请求给第三方大模型的参数
注入本地代码仓库相关的信息。从而实现更高的接受率

下面是我写的demo,这里只写了向量化的代码
有几个疑问:

  1. 提示词怎么构建才能发挥出作用?
  2. 如何生成向量数据库的 query?

各位佬友一起看看,觉得这个想法有没有作用~
求指导 :dizzy:
# Path: splitter.py
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters.base import Language


# Maps a lower-case, dot-less file extension to the langchain ``Language``
# enum value used to build a language-aware splitter.  Commented-out entries
# are extensions ``RecursiveCharacterTextSplitter`` has no dedicated support
# for (those files fall back to the generic splitter).
# NOTE(review): some mappings are approximations — "vue" reuses the HTML
# splitter and "json" the JS splitter; confirm the resulting chunks are OK.
EXTENSION_TO_LANGUAGE = {
    "js": Language.JS,
    "jsx": Language.JS,
    "mjs": Language.JS,
    "cjs": Language.JS,
    "py": Language.PYTHON,
    "java": Language.JAVA,
    "cpp": Language.CPP,
    "c": Language.C,
    "html": Language.HTML,
    "vue": Language.HTML,
    # "css": 'css', # Unsupported
    "json": Language.JS,
    # "xml": "xml", # Unsupported
    # "yaml": "yaml", # Unsupported
    "md": Language.MARKDOWN,
    # "sh": 'shell', # Unsupported
    "ts": Language.TS,
    "tsx": Language.TS,
    "rs": Language.RUST,
    "go": Language.GO,
    "rb": Language.RUBY,
    "php": Language.PHP,
    # "sql": '',
    "swift": Language.SWIFT,
    "kt": Language.KOTLIN,
    "scala": Language.SCALA,
    "lua": Language.LUA,
    "pl": Language.PERL,
    # "r": 'r', # Unsupported
    # "dart": 'dart', # Unsupported
    # "elm": "elm", # Unsupported
    # "clj": "clojure", # Unsupported
    "hs": Language.HASKELL,
    # "erl": "erlang",# Unsupported
    # "ex": "elixir",# Unsupported
    # "fs": "fsharp",# Unsupported
    # "vb": "visualbasic",# Unsupported
    # "asm": "asm",# Unsupported
    # "s": "asm",# Unsupported
    # "v": "verilog",# Unsupported
    # "vhdl": "vhdl",# Unsupported
    # "dockerfile": "dockerfile",# Unsupported
    # "makefile": "make",# Unsupported
    # "cmake": "cmake",# Unsupported
    # "ninja": "ninja",# Unsupported
    # "nix": "nix",# Unsupported
    # "zig": "zig",# Unsupported
    # "magik": "magik",# Unsupported
}


# Cache of configured splitters.  Keyed by language AND the kwargs used to
# build them: the original keyed on language only, so a second call with
# different chunk settings would silently return the first splitter.
splitter_cache = {}


def get_splitter(language: Language, **kwargs) -> RecursiveCharacterTextSplitter:
    """Return a (cached) ``RecursiveCharacterTextSplitter`` for *language*.

    Args:
        language: The langchain ``Language`` value to split for.
        **kwargs: Options forwarded to
            ``RecursiveCharacterTextSplitter.from_language`` (e.g.
            ``chunk_size``, ``chunk_overlap``); they are part of the cache key.

    Returns:
        A splitter instance shared by all callers using the same arguments.
    """
    # Sort kwargs so the key is independent of keyword order.
    key = (str(language), tuple(sorted(kwargs.items())))
    splitter = splitter_cache.get(key)
    if splitter is None:
        splitter = RecursiveCharacterTextSplitter.from_language(
            language=language,
            **kwargs,
        )
        splitter_cache[key] = splitter
    return splitter


def get_splitter_by_ext(ext: str) -> RecursiveCharacterTextSplitter:
    """Return a text splitter suited to the given file extension.

    Args:
        ext: File extension, case-insensitive, with or without a leading dot.

    Returns:
        A language-aware splitter when the extension is recognised, otherwise
        a generic recursive splitter with code-friendly separators.
    """
    # Normalise once: lower-case and strip any leading dot(s).  The original
    # lower-cased a second time inside the lookup, which was redundant.
    ext = ext.lower().lstrip(".")
    language = EXTENSION_TO_LANGUAGE.get(ext)
    if language is None:
        # Fallback for unsupported extensions: split on blank lines, newlines
        # and common statement/block terminators.
        return RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", "}", ";", ")"],
        )
    return get_splitter(
        language,
        chunk_size=1000,
        chunk_overlap=200,
    )


__all__ = ["get_splitter_by_ext"]
# Path: demo.py
import os
from typing import List, Union

import requests
from langchain_chroma import Chroma
from langchain_community.document_loaders import GitLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from splitter import get_splitter_by_ext

# --- API configuration -------------------------------------------------------
# Read the key from the environment when available; the literal fallback keeps
# the original demo behaviour, but a real key should never be hard-coded.
API_KEY = os.environ.get("SILICONFLOW_API_KEY", "sk-xxx")
# OpenAI-compatible embeddings endpoint (SiliconFlow).
EMBED_API_URL = "https://api.siliconflow.cn/v1/embeddings"
# Embedding model name sent with every request.
MODEL_NAME = "Pro/BAAI/bge-m3"
# Default on-disk location for the persisted Chroma vector store.
DEFAULT_PERSIST_DIR = "./chroma_api_db"


class APICodeEmbedder(Embeddings):
    """Embeddings backed by an OpenAI-compatible HTTP embeddings API.

    A single ``requests.Session`` is kept so the TCP connection is reused
    across calls.
    """

    def __init__(self):
        self._session = requests.Session()

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents in one request.

        The original issued one HTTP round-trip per text; the API accepts a
        list as ``input``, so all texts are sent together and the vectors are
        re-ordered by the returned ``index`` field.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            Exception: If the API responds with a non-200 status code.
        """
        if not texts:
            return []
        resp = self._session.post(
            EMBED_API_URL,
            headers={"Authorization": f"Bearer {API_KEY}"},
            # NOTE(review): the OpenAI-compatible spec names this field
            # "encoding_format"; "encoding_type" is kept from the original
            # code — confirm which name the provider actually honours.
            json={"input": texts, "model": MODEL_NAME, "encoding_type": "float"},
            timeout=30,
        )
        if resp.status_code != 200:
            raise Exception(f"Embedding失败: {resp.text}")
        data = resp.json()["data"]
        # Sort defensively: the spec allows items to arrive out of order.
        return [item["embedding"] for item in sorted(data, key=lambda d: d["index"])]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector for the text.
        """
        return self.embed_documents([text])[0]


class GitChromaFetcher:
    """Vectorise a local git repository into a persistent Chroma store.

    Files are loaded with ``GitLoader``, split with an extension-appropriate
    splitter, embedded, and persisted to ``persist_dir``.
    """

    def __init__(
        self,
        repo_path: str,
        persist_dir=None,
        embedder: Embeddings = None,
        loader_args: dict = None,
        embedder_args: dict = None,
    ):
        """Create a fetcher for *repo_path*.

        Args:
            repo_path: Path to the local git repository.
            persist_dir: Directory for the Chroma store; defaults to
                ``DEFAULT_PERSIST_DIR``.
            embedder: Embeddings implementation; defaults to an
                ``OpenAIEmbeddings`` client pointed at the configured API.
            loader_args: Extra keyword arguments for ``GitLoader``
                (e.g. ``branch``).
            embedder_args: Extra keyword arguments for the default embedder;
                ignored when *embedder* is supplied.
        """
        if not loader_args:
            loader_args = {}
        if not embedder_args:
            embedder_args = {}
        self._loader = GitLoader(repo_path, **loader_args)
        self._result_docs: List[Document] = []
        self._persist_dir = persist_dir or DEFAULT_PERSIST_DIR
        self._embedder = embedder or OpenAIEmbeddings(
            model=MODEL_NAME,
            openai_api_key=API_KEY,
            openai_api_base=EMBED_API_URL,
            **embedder_args,
        )

    def _split_doc(self, document: Union[Document, List[Document]]) -> None:
        """Split one document (or a list of them) and accumulate the chunks.

        The splitter is chosen per-document from the ``file_type`` metadata
        set by the loader.  (The original read ``.metadata`` off the list
        itself when given a list, which would raise ``AttributeError``.)
        """
        docs = document if isinstance(document, list) else [document]
        for doc in docs:
            # ``file_type`` may be absent; fall back to the generic splitter.
            splitter = get_splitter_by_ext(doc.metadata.get("file_type") or "")
            self._result_docs += splitter.split_documents([doc])

    def vectorization(self) -> Chroma:
        """Load, split, embed and persist the whole repository.

        Returns:
            The populated Chroma vector store.  (The original annotation
            claimed ``List[Document]``, but a store has always been returned.)
        """
        for document in self._loader.lazy_load():
            self._split_doc(document)
        return Chroma.from_documents(
            self._result_docs,
            embedding=self._embedder,
            persist_directory=self._persist_dir,
        )

    @property
    def chroma(self) -> Chroma:
        """A fresh Chroma handle onto the persisted store (new per access)."""
        return Chroma(
            persist_directory=self._persist_dir,
            embedding_function=self._embedder,
        )


def main():
    """Demo entry point: vectorise a repo, then run a similarity search."""
    fetcher = GitChromaFetcher(
        r"D:\Workspace\repo",
        embedder=APICodeEmbedder(),
        loader_args={"branch": "master-v2"},
    )
    fetcher.vectorization()
    # Inject the retrieved chunks into the prompt later on.
    separator = "=" * 30
    for doc in fetcher.chroma.similarity_search("用户"):
        print(separator)
        print(f"文件:{doc.metadata['file_path']}")
        print(f"{doc.page_content}")
        print(separator)


if __name__ == "__main__":
    main()

1 Like

现在作者只适配了阿里的向量模型,我觉得你可以看一下这一部分的代码,改成 OpenAI 标准格式的,本地接向量化模型
提示词肯定是copilot自己生成的好

emmm 好像 jb 的插件还支持使用嵌入模型?

jb不支持吧,vscode现在支持
作者已经更改成openai格式的embedding了

提示词是github copilot插件自己生成的,
所以才需要自己搭服务端,就是为了白嫖它生成的提示词
你只需要管embedding就行

我想的是在他的提示词再加一层:
针对当前项目的已经嵌入的代码库搜索一次,再追加到GitHub Copilot 生成的提示词上,然后发给 deepseek;
但是这个“搜索一次”如果直接拿 GitHub Copilot 生成的提示词去搜索,会不会因为太长反而效果不佳?可能要先精简一下,但不知道怎么精简。

此话题已在最后回复的 30 天后被自动关闭。不再允许新回复。