diff --git a/ChatView/ChatRootView.cpp b/ChatView/ChatRootView.cpp index 6846599..88d15ef 100644 --- a/ChatView/ChatRootView.cpp +++ b/ChatView/ChatRootView.cpp @@ -452,17 +452,36 @@ void ChatRootView::openChatHistoryFolder() void ChatRootView::testRAG() { auto project = ProjectExplorer::ProjectTree::currentProject(); - if (project) { - auto files = Context::ContextManager::instance().getProjectSourceFiles(project); - auto future = Context::RAGManager::instance().processFiles(project, files); - connect( - &Context::RAGManager::instance(), - &Context::RAGManager::vectorizationProgress, - this, - [](int processed, int total) { - qDebug() << "Processed" << processed << "of" << total << "files"; - }); + if (!project) { + qDebug() << "No active project found"; + return; } + + const QString TEST_QUERY = ""; + + qDebug() << "Starting RAG test with query:"; + qDebug() << TEST_QUERY; + qDebug() << "\nFirst, processing project files..."; + + auto files = Context::ContextManager::instance().getProjectSourceFiles(project); + auto future = Context::RAGManager::instance().processFiles(project, files); + + connect( + &Context::RAGManager::instance(), + &Context::RAGManager::vectorizationProgress, + this, + [](int processed, int total) { + qDebug() << QString("Vectorization progress: %1 of %2 files").arg(processed).arg(total); + }); + + connect( + &Context::RAGManager::instance(), + &Context::RAGManager::vectorizationFinished, + this, + [this, project, TEST_QUERY]() { + qDebug() << "\nVectorization completed. Starting similarity search...\n"; + Context::RAGManager::instance().searchSimilarDocuments(TEST_QUERY, project, 5); + }); } void ChatRootView::updateInputTokensCount() diff --git a/context/CMakeLists.txt b/context/CMakeLists.txt index 6f3a37b..132e127 100644 --- a/context/CMakeLists.txt +++ b/context/CMakeLists.txt @@ -9,6 +9,7 @@ add_library(Context STATIC RAGStorage.hpp RAGStorage.cpp RAGData.hpp RAGVectorizer.hpp RAGVectorizer.cpp + RAGSimilaritySearch.hpp RAGSimilaritySearch.cpp ) target_link_libraries(Context diff --git a/context/ContextManager.cpp b/context/ContextManager.cpp index 1e1785d..9ca7029 100644 --- a/context/ContextManager.cpp +++ b/context/ContextManager.cpp @@ -67,6 +67,28 @@ ContentFile ContextManager::createContentFile(const QString &filePath) const return contentFile; } +bool ContextManager::isInBuildDirectory(const QString &filePath) const +{ + static const QStringList buildDirPatterns + = {"/build/", + "/Build/", + "/BUILD/", + "/debug/", + "/Debug/", + "/DEBUG/", + "/release/", + "/Release/", + "/RELEASE/", + "/builds/"}; + + for (const QString &pattern : buildDirPatterns) { + if (filePath.contains(pattern)) { + return true; + } + } + return false; +} + QStringList ContextManager::getProjectSourceFiles(ProjectExplorer::Project *project) const { QStringList sourceFiles; @@ -79,8 +101,11 @@ QStringList ContextManager::getProjectSourceFiles(ProjectExplorer::Project *proj projectNode->forEachNode( [&sourceFiles, this](ProjectExplorer::FileNode *fileNode) { - if (fileNode && shouldProcessFile(fileNode->filePath().toString())) { - sourceFiles.append(fileNode->filePath().toString()); + if (fileNode) { + QString filePath = fileNode->filePath().toString(); + if (shouldProcessFile(filePath) && !isInBuildDirectory(filePath)) { + sourceFiles.append(filePath); + } } }, nullptr); diff --git a/context/ContextManager.hpp b/context/ContextManager.hpp index 3e67ed5..984b912 100644 --- a/context/ContextManager.hpp +++ b/context/ContextManager.hpp @@ -48,6 +48,7 @@ private: ContentFile createContentFile(const QString &filePath) const; bool shouldProcessFile(const QString &filePath) const; + bool isInBuildDirectory(const QString &filePath) const; }; } // namespace QodeAssist::Context diff --git a/context/RAGManager.cpp b/context/RAGManager.cpp index b48e701..9171200 100644 --- a/context/RAGManager.cpp +++ b/context/RAGManager.cpp @@ -18,11 +18,14 @@ */ #include "RAGManager.hpp" +#include "RAGSimilaritySearch.hpp" +#include "logger/Logger.hpp" #include #include #include #include +#include namespace QodeAssist::Context { @@ -39,6 +42,13 @@ RAGManager::RAGManager(QObject *parent) RAGManager::~RAGManager() {} +bool RAGManager::SearchResult::operator<(const SearchResult &other) const +{ + if (cosineScore != other.cosineScore) + return cosineScore > other.cosineScore; + return l2Score < other.l2Score; +} + QString RAGManager::getStoragePath(ProjectExplorer::Project *project) const { return QString("%1/qodeassist/%2/rag/vectors.db") @@ -165,7 +175,11 @@ QFuture RAGManager::processFile(ProjectExplorer::Project *project, const Q return promise->future(); } - auto vectorFuture = m_vectorizer->vectorizeText(QString::fromUtf8(file.readAll())); + QFileInfo fileInfo(filePath); + QString fileName = fileInfo.fileName(); + QString content = QString("// %1\n%2").arg(fileName, QString::fromUtf8(file.readAll())); + + auto vectorFuture = m_vectorizer->vectorizeText(content); vectorFuture.then([promise, filePath, this](const RAGVector &vector) { if (vector.empty()) { promise->addResult(false); @@ -214,4 +228,67 @@ bool RAGManager::isFileStorageOutdated( return m_currentStorage->needsUpdate(filePath); } +QFuture> RAGManager::search( + const QString &text, ProjectExplorer::Project *project, int topK) +{ + auto promise = std::make_shared>>(); + promise->start(); + + auto queryVectorFuture = m_vectorizer->vectorizeText(text); + queryVectorFuture.then([this, promise, project, topK](const RAGVector &queryVector) { + if (queryVector.empty()) { + LOG_MESSAGE("Failed to vectorize query text"); + promise->addResult(QList()); + promise->finish(); + return; + } + + auto storedFiles = getStoredFiles(project); + std::priority_queue results; + + for (const auto &filePath : storedFiles) { + auto storedVector = loadVectorFromStorage(project, filePath); + if (!storedVector.has_value()) + continue; + + float l2Score = RAGSimilaritySearch::l2Distance(queryVector, storedVector.value()); + float cosineScore + = RAGSimilaritySearch::cosineSimilarity(queryVector, storedVector.value()); + + results.push(SearchResult{filePath, l2Score, cosineScore}); + } + + QList resultsList; + int count = 0; + while (!results.empty() && count < topK) { + resultsList.append(results.top()); + results.pop(); + count++; + } + + promise->addResult(resultsList); + promise->finish(); + }); + + return promise->future(); +} + +void RAGManager::searchSimilarDocuments( + const QString &text, ProjectExplorer::Project *project, int topK) +{ + auto future = search(text, project, topK); + future.then([this](const QList &results) { logSearchResults(results); }); +} + +void RAGManager::logSearchResults(const QList &results) const +{ + qDebug() << QString("\nTop %1 similar documents:").arg(results.size()); + + for (const auto &result : results) { + qDebug() << QString("File: %1").arg(result.filePath); + qDebug() << QString(" Cosine Similarity: %1").arg(result.cosineScore); + qDebug() << QString(" L2 Distance: %1\n").arg(result.l2Score); + } +} + } // namespace QodeAssist::Context diff --git a/context/RAGManager.hpp b/context/RAGManager.hpp index 3c140e5..e54e733 100644 --- a/context/RAGManager.hpp +++ b/context/RAGManager.hpp @@ -40,11 +40,28 @@ class RAGManager : public QObject public: static RAGManager &instance(); + struct SearchResult + { + QString filePath; + float l2Score; + float cosineScore; + + bool operator<(const SearchResult &other) const; + }; + + // Process and vectorize files QFuture processFiles(ProjectExplorer::Project *project, const QStringList &filePaths); std::optional loadVectorFromStorage( ProjectExplorer::Project *project, const QString &filePath); QStringList getStoredFiles(ProjectExplorer::Project *project) const; bool isFileStorageOutdated(ProjectExplorer::Project *project, const QString &filePath) const; + RAGVectorizer *getVectorizer() const { return m_vectorizer.get(); } + + // Search functionality + QFuture> search( + const QString &text, ProjectExplorer::Project *project, int topK = 5); + void searchSimilarDocuments(const QString &text, ProjectExplorer::Project *project, int topK = 5); + void logSearchResults(const QList &results) const; signals: void vectorizationProgress(int processed, int total); @@ -53,6 +70,8 @@ signals: private: RAGManager(QObject *parent = nullptr); ~RAGManager(); + RAGManager(const RAGManager &) = delete; + RAGManager &operator=(const RAGManager &) = delete; QFuture processFile(ProjectExplorer::Project *project, const QString &filePath); void processNextBatch( diff --git a/context/RAGSimilaritySearch.cpp b/context/RAGSimilaritySearch.cpp new file mode 100644 index 0000000..535d192 --- /dev/null +++ b/context/RAGSimilaritySearch.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#include "RAGSimilaritySearch.hpp" +#include "logger/Logger.hpp" + +#include + +namespace QodeAssist::Context { + +float RAGSimilaritySearch::l2Distance(const RAGVector &v1, const RAGVector &v2) +{ + if (v1.size() != v2.size()) { + LOG_MESSAGE(QString("Vector size mismatch: %1 vs %2").arg(v1.size()).arg(v2.size())); + return std::numeric_limits::max(); + } + + float sum = 0.0f; + for (size_t i = 0; i < v1.size(); ++i) { + float diff = v1[i] - v2[i]; + sum += diff * diff; + } + return std::sqrt(sum); +} + +float RAGSimilaritySearch::cosineSimilarity(const RAGVector &v1, const RAGVector &v2) +{ + if (v1.size() != v2.size()) { + LOG_MESSAGE(QString("Vector size mismatch: %1 vs %2").arg(v1.size()).arg(v2.size())); + return 0.0f; + } + + float dotProduct = 0.0f; + float norm1 = 0.0f; + float norm2 = 0.0f; + + for (size_t i = 0; i < v1.size(); ++i) { + dotProduct += v1[i] * v2[i]; + norm1 += v1[i] * v1[i]; + norm2 += v2[i] * v2[i]; + } + + norm1 = std::sqrt(norm1); + norm2 = std::sqrt(norm2); + + if (norm1 == 0.0f || norm2 == 0.0f) + return 0.0f; + return dotProduct / (norm1 * norm2); +} + +} // namespace QodeAssist::Context diff --git a/context/RAGSimilaritySearch.hpp b/context/RAGSimilaritySearch.hpp new file mode 100644 index 0000000..9acf67b --- /dev/null +++ b/context/RAGSimilaritySearch.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include "RAGData.hpp" + +namespace QodeAssist::Context { + +class RAGSimilaritySearch +{ +public: + static float l2Distance(const RAGVector &v1, const RAGVector &v2); + + static float cosineSimilarity(const RAGVector &v1, const RAGVector &v2); + +private: + RAGSimilaritySearch() = delete; +}; + +} // namespace QodeAssist::Context