Skip to content

Commit 74a6c41

Browse files
committed
Add RAG-based chatbot for querying ingested data
Related to #48 Integrate a RAG-based chatbot for querying ingested GitHub repository data. * **Backend Changes**: - Add `src/gitingest/embedding.py` to handle vectorization of repository content. - Add `src/gitingest/rag_chatbot.py` to manage retrieval and generation logic. - Add new API endpoints in `src/routers/chatbot.py` for chatbot communication. * **Frontend Updates**: - Update `src/templates/index.jinja` and `src/templates/github.jinja` to include a chatbot UI with input and response display elements. * **Dependencies**: - Update `requirements.txt` to include `sentence-transformers`, `openai`, `langchain`, and `faiss-cpu`. * **Testing**: - Add test cases in `src/gitingest/tests/test_embedding.py` to validate embedding generation and storage. - Add test cases in `src/gitingest/tests/test_rag_chatbot.py` to validate retrieval accuracy and response quality. - Add test cases in `src/gitingest/tests/test_chatbot_endpoints.py` to validate API endpoint functionality.
1 parent dafe508 commit 74a6c41

File tree

9 files changed

+369
-2
lines changed

9 files changed

+369
-2
lines changed

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,7 @@ tiktoken
66
pytest
77
pytest-asyncio
88
click>=8.0.0
9+
sentence-transformers
10+
openai
11+
langchain
12+
faiss-cpu

src/gitingest/embedding.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from sentence_transformers import SentenceTransformer
2+
import numpy as np
3+
import os
4+
5+
class EmbeddingHandler:
6+
def __init__(self, model_name='all-MiniLM-L6-v2'):
7+
self.model = SentenceTransformer(model_name)
8+
9+
def get_embeddings(self, texts):
10+
return self.model.encode(texts, convert_to_tensor=True)
11+
12+
def save_embeddings(self, embeddings, file_path):
13+
np.save(file_path, embeddings)
14+
15+
def load_embeddings(self, file_path):
16+
return np.load(file_path)
17+
18+
def vectorize_repository(self, repo_path):
19+
file_contents = []
20+
for root, _, files in os.walk(repo_path):
21+
for file in files:
22+
file_path = os.path.join(root, file)
23+
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
24+
file_contents.append(f.read())
25+
return self.get_embeddings(file_contents)

src/gitingest/rag_chatbot.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
import openai
3+
from langchain.chains import RetrievalQA
4+
from langchain.llms import OpenAI
5+
from langchain.vectorstores import FAISS
6+
from langchain.embeddings import OpenAIEmbeddings
7+
from langchain.prompts import PromptTemplate
8+
9+
class RAGChatbot:
10+
def __init__(self, openai_api_key, vector_db_path):
11+
self.openai_api_key = openai_api_key
12+
openai.api_key = self.openai_api_key
13+
self.vector_db_path = vector_db_path
14+
self.vector_store = self.load_vector_store()
15+
16+
def load_vector_store(self):
17+
if os.path.exists(self.vector_db_path):
18+
return FAISS.load_local(self.vector_db_path, OpenAIEmbeddings())
19+
else:
20+
raise FileNotFoundError(f"Vector database not found at {self.vector_db_path}")
21+
22+
def save_vector_store(self):
23+
self.vector_store.save_local(self.vector_db_path)
24+
25+
def update_vector_store(self, texts):
26+
embeddings = OpenAIEmbeddings()
27+
new_vectors = embeddings.embed_documents(texts)
28+
self.vector_store.add_documents(new_vectors)
29+
self.save_vector_store()
30+
31+
def generate_response(self, query):
32+
retriever = self.vector_store.as_retriever()
33+
qa_chain = RetrievalQA(
34+
retriever=retriever,
35+
llm=OpenAI(api_key=self.openai_api_key),
36+
prompt_template=PromptTemplate(
37+
input_variables=["context", "question"],
38+
template="Context: {context}\n\nQuestion: {question}\n\nAnswer:"
39+
)
40+
)
41+
return qa_chain.run(query)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
from fastapi.testclient import TestClient
3+
from main import app
4+
5+
client = TestClient(app)
6+
7+
@pytest.fixture
8+
def mock_generate_response(monkeypatch):
9+
def mock_response(self, query):
10+
return "Mock response"
11+
monkeypatch.setattr("gitingest.rag_chatbot.RAGChatbot.generate_response", mock_response)
12+
13+
def test_query_chatbot(mock_generate_response):
14+
response = client.post("/chatbot/query", json={"query": "Test query"})
15+
assert response.status_code == 200
16+
assert response.json() == {"response": "Mock response"}
17+
18+
def test_query_chatbot_error(mock_generate_response, monkeypatch):
19+
def mock_response_error(self, query):
20+
raise Exception("Mock error")
21+
monkeypatch.setattr("gitingest.rag_chatbot.RAGChatbot.generate_response", mock_response_error)
22+
23+
response = client.post("/chatbot/query", json={"query": "Test query"})
24+
assert response.status_code == 500
25+
assert response.json() == {"detail": "Mock error"}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
import numpy as np
3+
import pytest
4+
from gitingest.embedding import EmbeddingHandler
5+
6+
@pytest.fixture
7+
def embedding_handler():
8+
return EmbeddingHandler()
9+
10+
def test_get_embeddings(embedding_handler):
11+
texts = ["Hello world", "Test sentence"]
12+
embeddings = embedding_handler.get_embeddings(texts)
13+
assert embeddings.shape[0] == 2
14+
assert embeddings.shape[1] > 0
15+
16+
def test_save_and_load_embeddings(embedding_handler, tmp_path):
17+
texts = ["Hello world", "Test sentence"]
18+
embeddings = embedding_handler.get_embeddings(texts)
19+
file_path = os.path.join(tmp_path, "embeddings.npy")
20+
embedding_handler.save_embeddings(embeddings, file_path)
21+
loaded_embeddings = embedding_handler.load_embeddings(file_path)
22+
assert np.array_equal(embeddings, loaded_embeddings)
23+
24+
def test_vectorize_repository(embedding_handler, tmp_path):
25+
repo_path = tmp_path / "repo"
26+
repo_path.mkdir()
27+
file1 = repo_path / "file1.txt"
28+
file1.write_text("Hello world")
29+
file2 = repo_path / "file2.txt"
30+
file2.write_text("Test sentence")
31+
embeddings = embedding_handler.vectorize_repository(str(repo_path))
32+
assert embeddings.shape[0] == 2
33+
assert embeddings.shape[1] > 0
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
from gitingest.rag_chatbot import RAGChatbot
3+
4+
@pytest.fixture
5+
def chatbot():
6+
return RAGChatbot(openai_api_key="your_openai_api_key", vector_db_path="path_to_vector_db")
7+
8+
def test_generate_response(chatbot):
9+
query = "What is the purpose of this repository?"
10+
response = chatbot.generate_response(query)
11+
assert response is not None
12+
assert isinstance(response, str)
13+
assert len(response) > 0
14+
15+
def test_update_vector_store(chatbot):
16+
texts = ["This is a test document.", "Another test document."]
17+
chatbot.update_vector_store(texts)
18+
assert chatbot.vector_store is not None
19+
assert len(chatbot.vector_store) > 0
20+
21+
def test_load_vector_store(chatbot):
22+
vector_store = chatbot.load_vector_store()
23+
assert vector_store is not None
24+
assert len(vector_store) > 0
25+
26+
def test_save_vector_store(chatbot):
27+
chatbot.save_vector_store()
28+
assert chatbot.vector_db_path.exists()

src/routers/chatbot.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from fastapi import APIRouter, HTTPException
2+
from pydantic import BaseModel
3+
from gitingest.rag_chatbot import RAGChatbot
4+
5+
router = APIRouter()
6+
7+
class ChatbotRequest(BaseModel):
8+
query: str
9+
10+
class ChatbotResponse(BaseModel):
11+
response: str
12+
13+
# Initialize the RAGChatbot with your OpenAI API key and vector database path
14+
chatbot = RAGChatbot(openai_api_key="your_openai_api_key", vector_db_path="path_to_vector_db")
15+
16+
@router.post("/chatbot/query", response_model=ChatbotResponse)
17+
async def query_chatbot(request: ChatbotRequest):
18+
try:
19+
response = chatbot.generate_response(request.query)
20+
return ChatbotResponse(response=response)
21+
except Exception as e:
22+
raise HTTPException(status_code=500, detail=str(e))

src/templates/github.jinja

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,100 @@
6060
});
6161
}
6262
</script>
63-
{% endblock extra_scripts %}
63+
64+
<!-- Chatbot UI -->
65+
<div class="chatbot-container">
66+
<div class="chatbot-header">
67+
<h2>Chat with our RAG-based Bot</h2>
68+
</div>
69+
<div class="chatbot-messages" id="chatbot-messages"></div>
70+
<div class="chatbot-input">
71+
<input type="text" id="chatbot-input" placeholder="Type your message here...">
72+
<button onclick="sendMessage()">Send</button>
73+
</div>
74+
</div>
75+
76+
<script>
77+
async function sendMessage() {
78+
const input = document.getElementById('chatbot-input');
79+
const messages = document.getElementById('chatbot-messages');
80+
const userMessage = input.value;
81+
if (!userMessage) return;
82+
83+
// Display user message
84+
const userMessageElement = document.createElement('div');
85+
userMessageElement.className = 'user-message';
86+
userMessageElement.textContent = userMessage;
87+
messages.appendChild(userMessageElement);
88+
89+
// Clear input
90+
input.value = '';
91+
92+
// Send message to chatbot API
93+
try {
94+
const response = await fetch('/chatbot/query', {
95+
method: 'POST',
96+
headers: {
97+
'Content-Type': 'application/json'
98+
},
99+
body: JSON.stringify({ query: userMessage })
100+
});
101+
const data = await response.json();
102+
103+
// Display chatbot response
104+
const botMessageElement = document.createElement('div');
105+
botMessageElement.className = 'bot-message';
106+
botMessageElement.textContent = data.response;
107+
messages.appendChild(botMessageElement);
108+
} catch (error) {
109+
console.error('Error:', error);
110+
}
111+
}
112+
</script>
113+
114+
<style>
115+
.chatbot-container {
116+
border: 1px solid #ccc;
117+
border-radius: 5px;
118+
padding: 10px;
119+
max-width: 400px;
120+
margin: 20px auto;
121+
}
122+
123+
.chatbot-header {
124+
text-align: center;
125+
font-weight: bold;
126+
}
127+
128+
.chatbot-messages {
129+
height: 200px;
130+
overflow-y: auto;
131+
border: 1px solid #ccc;
132+
padding: 10px;
133+
margin-bottom: 10px;
134+
}
135+
136+
.chatbot-input {
137+
display: flex;
138+
}
139+
140+
.chatbot-input input {
141+
flex: 1;
142+
padding: 5px;
143+
}
144+
145+
.chatbot-input button {
146+
padding: 5px 10px;
147+
}
148+
149+
.user-message {
150+
text-align: right;
151+
color: blue;
152+
}
153+
154+
.bot-message {
155+
text-align: left;
156+
color: green;
157+
}
158+
</style>
159+
{% endblock extra_scripts %}

src/templates/index.jinja

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,100 @@
6161

6262
{% include 'components/result.jinja' %}
6363

64+
<!-- Chatbot UI -->
65+
<div class="chatbot-container">
66+
<div class="chatbot-header">
67+
<h2>Chat with our RAG-based Bot</h2>
68+
</div>
69+
<div class="chatbot-messages" id="chatbot-messages"></div>
70+
<div class="chatbot-input">
71+
<input type="text" id="chatbot-input" placeholder="Type your message here...">
72+
<button onclick="sendMessage()">Send</button>
73+
</div>
74+
</div>
75+
76+
<script>
77+
async function sendMessage() {
78+
const input = document.getElementById('chatbot-input');
79+
const messages = document.getElementById('chatbot-messages');
80+
const userMessage = input.value;
81+
if (!userMessage) return;
82+
83+
// Display user message
84+
const userMessageElement = document.createElement('div');
85+
userMessageElement.className = 'user-message';
86+
userMessageElement.textContent = userMessage;
87+
messages.appendChild(userMessageElement);
88+
89+
// Clear input
90+
input.value = '';
91+
92+
// Send message to chatbot API
93+
try {
94+
const response = await fetch('/chatbot/query', {
95+
method: 'POST',
96+
headers: {
97+
'Content-Type': 'application/json'
98+
},
99+
body: JSON.stringify({ query: userMessage })
100+
});
101+
const data = await response.json();
102+
103+
// Display chatbot response
104+
const botMessageElement = document.createElement('div');
105+
botMessageElement.className = 'bot-message';
106+
botMessageElement.textContent = data.response;
107+
messages.appendChild(botMessageElement);
108+
} catch (error) {
109+
console.error('Error:', error);
110+
}
111+
}
112+
</script>
113+
114+
<style>
115+
.chatbot-container {
116+
border: 1px solid #ccc;
117+
border-radius: 5px;
118+
padding: 10px;
119+
max-width: 400px;
120+
margin: 20px auto;
121+
}
64122
123+
.chatbot-header {
124+
text-align: center;
125+
font-weight: bold;
126+
}
127+
128+
.chatbot-messages {
129+
height: 200px;
130+
overflow-y: auto;
131+
border: 1px solid #ccc;
132+
padding: 10px;
133+
margin-bottom: 10px;
134+
}
65135
136+
.chatbot-input {
137+
display: flex;
138+
}
66139
67-
{% endblock %}
140+
.chatbot-input input {
141+
flex: 1;
142+
padding: 5px;
143+
}
144+
145+
.chatbot-input button {
146+
padding: 5px 10px;
147+
}
148+
149+
.user-message {
150+
text-align: right;
151+
color: blue;
152+
}
153+
154+
.bot-message {
155+
text-align: left;
156+
color: green;
157+
}
158+
</style>
159+
160+
{% endblock %}

0 commit comments

Comments
 (0)