From 74a6c417387787459840757c14356f50bb6d2bb9 Mon Sep 17 00:00:00 2001
From: Vishwanath Martur <64204611+vishwamartur@users.noreply.github.com>
Date: Wed, 25 Dec 2024 03:19:14 +0530
Subject: [PATCH] Add RAG-based chatbot for querying ingested data

Related to #48

Integrate a RAG-based chatbot for querying ingested GitHub repository data.

* **Backend Changes**:
  - Add `src/gitingest/embedding.py` to handle vectorization of repository content.
  - Add `src/gitingest/rag_chatbot.py` to manage retrieval and generation logic.
  - Add new API endpoints in `src/routers/chatbot.py` for chatbot communication.

* **Frontend Updates**:
  - Update `src/templates/index.jinja` and `src/templates/github.jinja` to include a chatbot UI with input and response display elements.

* **Dependencies**:
  - Update `requirements.txt` to include `sentence-transformers`, `openai`, `langchain`, and `faiss-cpu`.

* **Testing**:
  - Add test cases in `src/gitingest/tests/test_embedding.py` to validate embedding generation and storage.
  - Add test cases in `src/gitingest/tests/test_rag_chatbot.py` to validate retrieval accuracy and response quality.
  - Add test cases in `src/gitingest/tests/test_chatbot_endpoints.py` to validate API endpoint functionality.
--- requirements.txt | 4 + src/gitingest/embedding.py | 25 +++++ src/gitingest/rag_chatbot.py | 41 ++++++++ src/gitingest/tests/test_chatbot_endpoints.py | 25 +++++ src/gitingest/tests/test_embedding.py | 33 +++++++ src/gitingest/tests/test_rag_chatbot.py | 28 ++++++ src/routers/chatbot.py | 22 +++++ src/templates/github.jinja | 98 ++++++++++++++++++- src/templates/index.jinja | 95 +++++++++++++++++- 9 files changed, 369 insertions(+), 2 deletions(-) create mode 100644 src/gitingest/embedding.py create mode 100644 src/gitingest/rag_chatbot.py create mode 100644 src/gitingest/tests/test_chatbot_endpoints.py create mode 100644 src/gitingest/tests/test_embedding.py create mode 100644 src/gitingest/tests/test_rag_chatbot.py create mode 100644 src/routers/chatbot.py diff --git a/requirements.txt b/requirements.txt index 6848603b..906987f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,7 @@ tiktoken pytest pytest-asyncio click>=8.0.0 +sentence-transformers +openai +langchain +faiss-cpu diff --git a/src/gitingest/embedding.py b/src/gitingest/embedding.py new file mode 100644 index 00000000..51c98c62 --- /dev/null +++ b/src/gitingest/embedding.py @@ -0,0 +1,25 @@ +from sentence_transformers import SentenceTransformer +import numpy as np +import os + +class EmbeddingHandler: + def __init__(self, model_name='all-MiniLM-L6-v2'): + self.model = SentenceTransformer(model_name) + + def get_embeddings(self, texts): + return self.model.encode(texts, convert_to_tensor=True) + + def save_embeddings(self, embeddings, file_path): + np.save(file_path, embeddings) + + def load_embeddings(self, file_path): + return np.load(file_path) + + def vectorize_repository(self, repo_path): + file_contents = [] + for root, _, files in os.walk(repo_path): + for file in files: + file_path = os.path.join(root, file) + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + file_contents.append(f.read()) + return self.get_embeddings(file_contents) diff --git 
a/src/gitingest/rag_chatbot.py b/src/gitingest/rag_chatbot.py
new file mode 100644
index 00000000..6946008d
--- /dev/null
+++ b/src/gitingest/rag_chatbot.py
@@ -0,0 +1,45 @@
+"""RAG chatbot: FAISS retrieval over ingested repository data + OpenAI generation."""
+import os
+import openai
+from langchain.chains import RetrievalQA
+from langchain.llms import OpenAI
+from langchain.vectorstores import FAISS
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.prompts import PromptTemplate
+
+class RAGChatbot:
+    def __init__(self, openai_api_key, vector_db_path):
+        self.openai_api_key = openai_api_key
+        openai.api_key = self.openai_api_key
+        self.vector_db_path = vector_db_path
+        self.vector_store = self.load_vector_store()
+
+    def load_vector_store(self):
+        # Fail fast with a clear message instead of an opaque FAISS error.
+        if os.path.exists(self.vector_db_path):
+            return FAISS.load_local(self.vector_db_path, OpenAIEmbeddings())
+        raise FileNotFoundError(f"Vector database not found at {self.vector_db_path}")
+
+    def save_vector_store(self):
+        self.vector_store.save_local(self.vector_db_path)
+
+    def update_vector_store(self, texts):
+        # add_texts() embeds the raw strings itself; the previous code fed
+        # pre-computed vectors to add_documents(), which expects Document
+        # objects and would raise at runtime.
+        self.vector_store.add_texts(texts)
+        self.save_vector_store()
+
+    def generate_response(self, query):
+        # RetrievalQA has no (retriever=, llm=, prompt_template=) constructor;
+        # build it via from_chain_type() and pass the prompt in chain_type_kwargs.
+        prompt = PromptTemplate(
+            input_variables=["context", "question"],
+            template="Context: {context}\n\nQuestion: {question}\n\nAnswer:"
+        )
+        qa_chain = RetrievalQA.from_chain_type(
+            llm=OpenAI(openai_api_key=self.openai_api_key),
+            chain_type="stuff",
+            retriever=self.vector_store.as_retriever(),
+            chain_type_kwargs={"prompt": prompt},
+        )
+        return qa_chain.run(query)
diff --git a/src/gitingest/tests/test_chatbot_endpoints.py b/src/gitingest/tests/test_chatbot_endpoints.py
new file mode 100644
index 00000000..4c38a912
--- /dev/null
+++ b/src/gitingest/tests/test_chatbot_endpoints.py
@@ -0,0 +1,25 @@
+import pytest
+from fastapi.testclient import TestClient
+from main import app
+
+client = TestClient(app)
+
+@pytest.fixture
+def mock_generate_response(monkeypatch):
+    def mock_response(self, query):
+        return "Mock response"
+
monkeypatch.setattr("gitingest.rag_chatbot.RAGChatbot.generate_response", mock_response) + +def test_query_chatbot(mock_generate_response): + response = client.post("/chatbot/query", json={"query": "Test query"}) + assert response.status_code == 200 + assert response.json() == {"response": "Mock response"} + +def test_query_chatbot_error(mock_generate_response, monkeypatch): + def mock_response_error(self, query): + raise Exception("Mock error") + monkeypatch.setattr("gitingest.rag_chatbot.RAGChatbot.generate_response", mock_response_error) + + response = client.post("/chatbot/query", json={"query": "Test query"}) + assert response.status_code == 500 + assert response.json() == {"detail": "Mock error"} diff --git a/src/gitingest/tests/test_embedding.py b/src/gitingest/tests/test_embedding.py new file mode 100644 index 00000000..9747b049 --- /dev/null +++ b/src/gitingest/tests/test_embedding.py @@ -0,0 +1,33 @@ +import os +import numpy as np +import pytest +from gitingest.embedding import EmbeddingHandler + +@pytest.fixture +def embedding_handler(): + return EmbeddingHandler() + +def test_get_embeddings(embedding_handler): + texts = ["Hello world", "Test sentence"] + embeddings = embedding_handler.get_embeddings(texts) + assert embeddings.shape[0] == 2 + assert embeddings.shape[1] > 0 + +def test_save_and_load_embeddings(embedding_handler, tmp_path): + texts = ["Hello world", "Test sentence"] + embeddings = embedding_handler.get_embeddings(texts) + file_path = os.path.join(tmp_path, "embeddings.npy") + embedding_handler.save_embeddings(embeddings, file_path) + loaded_embeddings = embedding_handler.load_embeddings(file_path) + assert np.array_equal(embeddings, loaded_embeddings) + +def test_vectorize_repository(embedding_handler, tmp_path): + repo_path = tmp_path / "repo" + repo_path.mkdir() + file1 = repo_path / "file1.txt" + file1.write_text("Hello world") + file2 = repo_path / "file2.txt" + file2.write_text("Test sentence") + embeddings = 
embedding_handler.vectorize_repository(str(repo_path))
+    assert embeddings.shape[0] == 2
+    assert embeddings.shape[1] > 0
diff --git a/src/gitingest/tests/test_rag_chatbot.py b/src/gitingest/tests/test_rag_chatbot.py
new file mode 100644
index 00000000..88de7f63
--- /dev/null
+++ b/src/gitingest/tests/test_rag_chatbot.py
@@ -0,0 +1,32 @@
+import os
+import pytest
+from gitingest.rag_chatbot import RAGChatbot
+
+@pytest.fixture
+def chatbot():
+    # NOTE(review): requires a real OpenAI key and a prebuilt FAISS index at
+    # this path; with these placeholders the constructor raises FileNotFoundError.
+    return RAGChatbot(openai_api_key="your_openai_api_key", vector_db_path="path_to_vector_db")
+
+def test_generate_response(chatbot):
+    query = "What is the purpose of this repository?"
+    response = chatbot.generate_response(query)
+    assert response is not None
+    assert isinstance(response, str)
+    assert len(response) > 0
+
+def test_update_vector_store(chatbot):
+    texts = ["This is a test document.", "Another test document."]
+    chatbot.update_vector_store(texts)
+    assert chatbot.vector_store is not None
+    # FAISS stores expose their size via index.ntotal; len() is unsupported.
+    assert chatbot.vector_store.index.ntotal > 0
+
+def test_load_vector_store(chatbot):
+    vector_store = chatbot.load_vector_store()
+    assert vector_store is not None
+    assert vector_store.index.ntotal > 0
+
+def test_save_vector_store(chatbot):
+    chatbot.save_vector_store()
+    # vector_db_path is a plain str, so it has no .exists(); use os.path.
+    assert os.path.exists(chatbot.vector_db_path)
diff --git a/src/routers/chatbot.py b/src/routers/chatbot.py
new file mode 100644
index 00000000..5d51e160
--- /dev/null
+++ b/src/routers/chatbot.py
@@ -0,0 +1,22 @@
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel
+from gitingest.rag_chatbot import RAGChatbot
+
+router = APIRouter()
+
+class ChatbotRequest(BaseModel):
+    query: str
+
+class ChatbotResponse(BaseModel):
+    response: str
+
+# NOTE(review): constructing the chatbot at import time with placeholder
+# credentials raises FileNotFoundError (load_vector_store runs in __init__),
+# which breaks every import of the app. Inject the key/path from config or
+# build the chatbot lazily inside the endpoint.
+chatbot = RAGChatbot(openai_api_key="your_openai_api_key", vector_db_path="path_to_vector_db")
+
+@router.post("/chatbot/query", response_model=ChatbotResponse)
+async def query_chatbot(request: ChatbotRequest):
+    try:
+        response =
chatbot.generate_response(request.query) + return ChatbotResponse(response=response) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/templates/github.jinja b/src/templates/github.jinja index fdedcce7..7f64ecaf 100644 --- a/src/templates/github.jinja +++ b/src/templates/github.jinja @@ -60,4 +60,100 @@ }); } -{% endblock extra_scripts %} \ No newline at end of file + + +