feat: RAG init

This commit is contained in:
Petr Mironychev 2025-01-26 17:51:01 +01:00
parent 1fa6a225a4
commit 5a426b4d9f
14 changed files with 726 additions and 4 deletions

View File

@ -24,21 +24,23 @@
#include <QFileDialog>
#include <QMessageBox>
#include <coreplugin/editormanager/editormanager.h>
#include <coreplugin/icore.h>
#include <projectexplorer/project.h>
#include <projectexplorer/projectexplorer.h>
#include <projectexplorer/projectmanager.h>
#include <projectexplorer/projecttree.h>
#include <utils/theme/theme.h>
#include <utils/utilsicons.h>
#include <coreplugin/editormanager/editormanager.h>
#include "ChatAssistantSettings.hpp"
#include "ChatSerializer.hpp"
#include "GeneralSettings.hpp"
#include "Logger.hpp"
#include "ProjectSettings.hpp"
#include "context/TokenUtils.hpp"
#include "context/ContextManager.hpp"
#include "context/RAGManager.hpp"
#include "context/TokenUtils.hpp"
namespace QodeAssist::Chat {
@ -447,6 +449,22 @@ void ChatRootView::openChatHistoryFolder()
QDesktopServices::openUrl(url);
}
void ChatRootView::testRAG()
{
auto project = ProjectExplorer::ProjectTree::currentProject();
if (project) {
auto files = Context::ContextManager::instance().getProjectSourceFiles(project);
auto future = Context::RAGManager::instance().processFiles(project, files);
connect(
&Context::RAGManager::instance(),
&Context::RAGManager::vectorizationProgress,
this,
[](int processed, int total) {
qDebug() << "Processed" << processed << "of" << total << "files";
});
}
}
void ChatRootView::updateInputTokensCount()
{
int inputTokens = m_messageTokensCount;

View File

@ -64,7 +64,7 @@ public:
Q_INVOKABLE void removeFileFromLinkList(int index);
Q_INVOKABLE void calculateMessageTokensCount(const QString &message);
Q_INVOKABLE void setIsSyncOpenFiles(bool state);
Q_INVOKABLE void openChatHistoryFolder();
// Runs RAG processing over the current project's source files and logs progress (wired to the "Test RAG" button).
Q_INVOKABLE void testRAG();
Q_INVOKABLE void updateInputTokensCount();
int inputTokensCount() const;

View File

@ -198,6 +198,7 @@ ChatRootView {
}
attachFiles.onClicked: root.showAttachFilesDialog()
linkFiles.onClicked: root.showLinkFilesDialog()
testRag.onClicked: root.testRAG()
}
}

View File

@ -30,6 +30,7 @@ Rectangle {
property alias syncOpenFiles: syncOpenFilesId
property alias attachFiles: attachFilesId
property alias linkFiles: linkFilesId
property alias testRag: testRagId
color: palette.window.hslLightness > 0.5 ?
Qt.darker(palette.window, 1.1) :
@ -91,6 +92,12 @@ Rectangle {
ToolTip.text: qsTr("Automatically synchronize currently opened files with the model context")
}
// Button labelled "Test RAG"; it is referenced from outside through its id.
QoAButton {
id: testRagId
text: qsTr("Test RAG")
}
Item {
Layout.fillWidth: true
}

View File

@ -5,11 +5,16 @@ add_library(Context STATIC
ContentFile.hpp
TokenUtils.hpp TokenUtils.cpp
ProgrammingLanguage.hpp ProgrammingLanguage.cpp
RAGManager.hpp RAGManager.cpp
RAGStorage.hpp RAGStorage.cpp
RAGData.hpp
RAGVectorizer.hpp RAGVectorizer.cpp
)
target_link_libraries(Context
PUBLIC
Qt::Core
Qt::Sql
QtCreator::Core
QtCreator::TextEditor
QtCreator::Utils

View File

@ -23,6 +23,9 @@
#include <QFileInfo>
#include <QTextStream>
#include <projectexplorer/project.h>
#include <projectexplorer/projectnodes.h>
namespace QodeAssist::Context {
ContextManager &ContextManager::instance()
@ -64,4 +67,34 @@ ContentFile ContextManager::createContentFile(const QString &filePath) const
return contentFile;
}
/**
 * Collects the paths of all file nodes below the project's root node that are
 * accepted by shouldProcessFile().
 *
 * @param project Project to scan; may be null.
 * @return The accepted file paths in traversal order, or an empty list when
 *         the project or its root project node is missing.
 */
QStringList ContextManager::getProjectSourceFiles(ProjectExplorer::Project *project) const
{
    auto *rootNode = project ? project->rootProjectNode() : nullptr;
    if (!rootNode)
        return {};

    QStringList collected;
    const auto collectIfSupported = [this, &collected](ProjectExplorer::FileNode *node) {
        if (!node)
            return;
        const QString path = node->filePath().toString();
        if (shouldProcessFile(path))
            collected.append(path);
    };
    // Only file nodes are of interest, hence no folder callback.
    rootNode->forEachNode(collectIfSupported, nullptr);
    return collected;
}
bool ContextManager::shouldProcessFile(const QString &filePath) const
{
static const QStringList supportedExtensions
= {"cpp", "hpp", "c", "h", "cc", "hh", "cxx", "hxx", "qml", "js", "py"};
QFileInfo fileInfo(filePath);
return supportedExtensions.contains(fileInfo.suffix().toLower());
}
} // namespace QodeAssist::Context

View File

@ -19,10 +19,13 @@
#pragma once
#include "ContentFile.hpp"
#include <QObject>
#include <QString>
#include "ContentFile.hpp"
// Forward declaration; the full definition comes from <projectexplorer/project.h> in the .cpp.
namespace ProjectExplorer {
class Project;
}
namespace QodeAssist::Context {
@ -32,15 +35,19 @@ class ContextManager : public QObject
public:
static ContextManager &instance();
QString readFile(const QString &filePath) const;
QList<ContentFile> getContentFiles(const QStringList &filePaths) const;
// Returns the paths of the project's file nodes accepted by shouldProcessFile(); empty when project is null.
QStringList getProjectSourceFiles(ProjectExplorer::Project *project) const;
private:
explicit ContextManager(QObject *parent = nullptr);
~ContextManager() = default;
ContextManager(const ContextManager &) = delete;
ContextManager &operator=(const ContextManager &) = delete;
ContentFile createContentFile(const QString &filePath) const;
// True when the file's suffix, compared in lower case, is one of the supported source extensions.
bool shouldProcessFile(const QString &filePath) const;
};
} // namespace QodeAssist::Context

7
context/RAGData.hpp Normal file
View File

@ -0,0 +1,7 @@
#pragma once
#include <vector>
namespace QodeAssist::Context {
// Vector of single-precision floats used as the RAG vector type throughout the Context library.
using RAGVector = std::vector<float>;
}

217
context/RAGManager.cpp Normal file
View File

@ -0,0 +1,217 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#include "RAGManager.hpp"
#include <coreplugin/icore.h>
#include <projectexplorer/project.h>
#include <QFile>
#include <QtConcurrent>
namespace QodeAssist::Context {
// Returns the single shared RAGManager, constructed on first call.
RAGManager &RAGManager::instance()
{
    static RAGManager s_instance;
    return s_instance;
}
RAGManager::RAGManager(QObject *parent)
: QObject(parent)
, m_vectorizer(std::make_unique<RAGVectorizer>())
{}
// Nothing to release by hand: the vectorizer and storage are held by unique_ptr members.
RAGManager::~RAGManager() = default;
/**
 * Builds the path of the vector database file for a project:
 * <user resource path>/qodeassist/<project display name>/rag/vectors.db
 *
 * @param project Project whose display name is used; must not be null
 *                (it is dereferenced unconditionally).
 */
QString RAGManager::getStoragePath(ProjectExplorer::Project *project) const
{
    const QString resourceRoot = Core::ICore::userResourcePath().toString();
    const QString projectName = project->displayName();
    return QString("%1/qodeassist/%2/rag/vectors.db").arg(resourceRoot, projectName);
}
/**
 * Makes m_currentStorage the opened storage of the given project.
 *
 * Keeps the existing storage when it already belongs to the project; otherwise
 * closes it and opens the project's database. After the call m_currentStorage
 * is null when project is null or when the database could not be initialized,
 * so callers can rely on a non-null storage being usable.
 */
void RAGManager::ensureStorageForProject(ProjectExplorer::Project *project)
{
    if (m_currentProject == project && m_currentStorage)
        return;

    m_currentStorage.reset();
    m_currentProject = project;
    if (!project)
        return;

    auto storage = std::make_unique<RAGStorage>(getStoragePath(project), this);
    // Only publish the storage once it is known to work; a storage whose init()
    // failed would otherwise look valid to every `if (!m_currentStorage)` check.
    if (!storage->init()) {
        qWarning() << "Failed to initialize RAG storage at" << storage->dbPath();
        return;
    }
    m_currentStorage = std::move(storage);
}
/**
 * Vectorizes and stores every file of the list whose stored vector is missing
 * or outdated, working in batches of 10 files.
 *
 * @param project   Project the files belong to; a null project yields an
 *                  already finished future and no work.
 * @param filePaths Candidate files.
 * @return Future that finishes when all batches are done, or immediately when
 *         the storage cannot be opened or nothing needs processing.
 *         vectorizationFinished() is emitted when processing completes or
 *         when nothing needed processing.
 */
QFuture<void> RAGManager::processFiles(
    ProjectExplorer::Project *project, const QStringList &filePaths)
{
    auto promise = std::make_shared<QPromise<void>>();
    promise->start();

    // Guard before the first dereference of project below.
    if (!project) {
        qDebug() << "Cannot process files without a project";
        promise->finish();
        return promise->future();
    }

    qDebug() << "Starting batch processing of" << filePaths.size()
             << "files for project:" << project->displayName();

    ensureStorageForProject(project);
    if (!m_currentStorage) {
        qDebug() << "Failed to initialize storage for project:" << project->displayName();
        promise->finish();
        return promise->future();
    }

    const int batchSize = 10;

    QStringList filesToProcess;
    for (const QString &filePath : filePaths) {
        if (isFileStorageOutdated(project, filePath)) {
            qDebug() << "File needs processing:" << filePath;
            filesToProcess.append(filePath);
        }
    }

    if (filesToProcess.isEmpty()) {
        qDebug() << "No files need processing";
        emit vectorizationFinished();
        promise->finish();
        return promise->future();
    }

    qDebug() << "Processing" << filesToProcess.size() << "files in batches of" << batchSize;
    processNextBatch(promise, project, filesToProcess, 0, batchSize);
    return promise->future();
}
/**
 * Processes files[startIndex, startIndex + batchSize) and continues with the
 * following batch once every file of the current one has completed.
 *
 * @param promise    Promise of the whole run; finished when startIndex is past
 *                   the end of the list.
 * @param project    Project the files belong to.
 * @param files      Complete list of files of the run.
 * @param startIndex Index of the first file of this batch.
 * @param batchSize  Number of files per batch; values below 1 are treated as 1.
 */
void RAGManager::processNextBatch(
    std::shared_ptr<QPromise<void>> promise,
    ProjectExplorer::Project *project,
    const QStringList &files,
    int startIndex,
    int batchSize)
{
    const int totalFiles = static_cast<int>(files.size());
    if (startIndex >= totalFiles) {
        qDebug() << "All batches processed";
        emit vectorizationFinished();
        promise->finish();
        return;
    }

    // A non-positive batch size would produce an empty batch and stall the run.
    const int step = qMax(1, batchSize);
    const int endIndex = qMin(startIndex + step, totalFiles);
    const QStringList currentBatch = files.mid(startIndex, endIndex - startIndex);

    qDebug() << "Processing batch" << startIndex / step + 1 << "files" << startIndex << "to"
             << endIndex;

    // Files complete in arbitrary order (each one waits for its own network
    // reply), so count completions instead of watching for one specific file.
    auto pending = std::make_shared<int>(static_cast<int>(currentBatch.size()));

    for (const QString &filePath : currentBatch) {
        qDebug() << "Starting processing of file:" << filePath;

        auto *watcher = new QFutureWatcher<bool>(this);
        // Connect before setFuture(), as the QFutureWatcher documentation
        // requires, so an already finished future is still reported.
        connect(
            watcher,
            &QFutureWatcher<bool>::finished,
            this,
            [this, watcher, promise, project, files, endIndex, step, pending, filePath]() {
                // A future that ended without a result counts as a failure.
                const bool success = watcher->future().resultCount() > 0 && watcher->result();
                qDebug() << "File processed:" << filePath << "success:" << success;
                watcher->deleteLater();

                if (--*pending == 0) {
                    qDebug() << "Batch completed, moving to next batch";
                    emit vectorizationProgress(endIndex, static_cast<int>(files.size()));
                    processNextBatch(promise, project, files, endIndex, step);
                }
            });
        watcher->setFuture(processFile(project, filePath));
    }
}
/**
 * Reads one file, requests its vector from the vectorizer and stores it.
 *
 * @param project  Project whose storage receives the vector.
 * @param filePath File to vectorize; its content is decoded as UTF-8.
 * @return Future holding true when the vector was stored. It holds false when
 *         the storage or the file cannot be opened, the vectorizer returns an
 *         empty vector, the storage has meanwhile been closed or switched to
 *         another project, or the store query fails.
 */
QFuture<bool> RAGManager::processFile(ProjectExplorer::Project *project, const QString &filePath)
{
    auto promise = std::make_shared<QPromise<bool>>();
    promise->start();

    // Shared exit for failures detected before the request is sent.
    const auto failNow = [&promise]() {
        promise->addResult(false);
        promise->finish();
        return promise->future();
    };

    ensureStorageForProject(project);
    if (!m_currentStorage)
        return failNow();

    QFile file(filePath);
    if (!file.open(QIODevice::ReadOnly))
        return failNow();

    auto vectorFuture = m_vectorizer->vectorizeText(QString::fromUtf8(file.readAll()));
    vectorFuture.then([promise, project, filePath, this](const RAGVector &vector) {
        // The vector arrives asynchronously. By now the storage may have been
        // reset or re-opened for a different project, so verify it again
        // instead of dereferencing it blindly.
        const bool storageStillValid = m_currentStorage && m_currentProject == project;
        const bool success = !vector.empty() && storageStillValid
                             && m_currentStorage->storeVector(filePath, vector);
        promise->addResult(success);
        promise->finish();
    });

    return promise->future();
}
/**
 * Fetches the stored vector of a file, switching the current storage to the
 * given project first.
 *
 * @return The stored vector, or std::nullopt when no storage is available or
 *         the storage has no vector for the file.
 */
std::optional<RAGVector> RAGManager::loadVectorFromStorage(
    ProjectExplorer::Project *project, const QString &filePath)
{
    ensureStorageForProject(project);
    if (m_currentStorage)
        return m_currentStorage->getVector(filePath);
    return std::nullopt;
}
/**
 * Lists the file paths that have a stored vector for the given project.
 *
 * Uses the currently open storage when it belongs to the project; otherwise a
 * temporary storage is opened on the project's database for this call only.
 *
 * @param project Project to query; a null project yields an empty list.
 * @return Stored file paths, or an empty list when the database cannot be
 *         opened.
 */
QStringList RAGManager::getStoredFiles(ProjectExplorer::Project *project) const
{
    // getStoragePath() dereferences the project, so reject null up front.
    if (!project)
        return {};

    if (m_currentProject == project && m_currentStorage)
        return m_currentStorage->getAllFiles();

    RAGStorage tempStorage(getStoragePath(project), nullptr);
    if (!tempStorage.init())
        return {};
    return tempStorage.getAllFiles();
}
/**
 * Tells whether a file has to be (re-)vectorized for the given project.
 *
 * Uses the currently open storage when it belongs to the project; otherwise a
 * temporary storage is opened on the project's database for this call only.
 *
 * @param project  Project to query; with a null project the file is reported
 *                 as outdated.
 * @param filePath File to check.
 * @return true when the storage reports the file as needing an update or the
 *         database cannot be opened.
 */
bool RAGManager::isFileStorageOutdated(
    ProjectExplorer::Project *project, const QString &filePath) const
{
    // getStoragePath() dereferences the project, so reject null up front.
    if (!project)
        return true;

    if (m_currentProject == project && m_currentStorage)
        return m_currentStorage->needsUpdate(filePath);

    RAGStorage tempStorage(getStoragePath(project), nullptr);
    if (!tempStorage.init())
        return true;
    return tempStorage.needsUpdate(filePath);
}
} // namespace QodeAssist::Context

72
context/RAGManager.hpp Normal file
View File

@ -0,0 +1,72 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <memory>
#include <QFuture>
#include <QObject>
#include <QString>
#include "RAGStorage.hpp"
#include "RAGVectorizer.hpp"
#include <RAGData.hpp>
// Forward declaration: this header only handles ProjectExplorer::Project through pointers.
namespace ProjectExplorer {
class Project;
}
namespace QodeAssist::Context {
// Singleton that turns project files into vectors through a RAGVectorizer and
// keeps them in a RAGStorage opened for one project at a time.
class RAGManager : public QObject
{
Q_OBJECT
public:
// Returns the shared instance.
static RAGManager &instance();
// Vectorizes the files of the list that need an update; the future finishes when the run ends.
QFuture<void> processFiles(ProjectExplorer::Project *project, const QStringList &filePaths);
// Returns the stored vector of filePath, or std::nullopt when none is available.
std::optional<RAGVector> loadVectorFromStorage(
ProjectExplorer::Project *project, const QString &filePath);
// Lists the file paths that have a stored vector for the project.
QStringList getStoredFiles(ProjectExplorer::Project *project) const;
// True when the storage reports that filePath needs to be (re-)vectorized.
bool isFileStorageOutdated(ProjectExplorer::Project *project, const QString &filePath) const;
signals:
// Emitted after a batch completes: number of files handled so far and total files of the run.
void vectorizationProgress(int processed, int total);
// Emitted when a run has no more files to process.
void vectorizationFinished();
private:
// Private: instances are obtained through instance() only.
RAGManager(QObject *parent = nullptr);
~RAGManager();
// Vectorizes and stores a single file; the future holds whether that succeeded.
QFuture<bool> processFile(ProjectExplorer::Project *project, const QString &filePath);
// Processes one batch starting at startIndex, then continues with the next one.
void processNextBatch(
std::shared_ptr<QPromise<void>> promise,
ProjectExplorer::Project *project,
const QStringList &files,
int startIndex,
int batchSize);
// Makes m_currentStorage refer to the storage of the given project.
void ensureStorageForProject(ProjectExplorer::Project *project);
// Path of the project's vector database below the user resource path.
QString getStoragePath(ProjectExplorer::Project *project) const;
std::unique_ptr<RAGVectorizer> m_vectorizer;
// Storage of m_currentProject; null while no storage is open.
std::unique_ptr<RAGStorage> m_currentStorage;
ProjectExplorer::Project *m_currentProject{nullptr};
};
} // namespace QodeAssist::Context

165
context/RAGStorage.cpp Normal file
View File

@ -0,0 +1,165 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#include "RAGStorage.hpp"
#include <QDir>
#include <QFileInfo>
#include <QSqlError>
#include <QSqlQuery>
namespace QodeAssist::Context {
RAGStorage::RAGStorage(const QString &dbPath, QObject *parent)
: QObject(parent)
, m_dbPath(dbPath)
{}
/**
 * Closes the database and unregisters its connection.
 *
 * Closing alone leaves the named connection in QSqlDatabase's registry, which
 * makes the next addDatabase() with the same name warn about a duplicate.
 */
RAGStorage::~RAGStorage()
{
    const QString connectionName = m_db.connectionName();
    if (m_db.isOpen())
        m_db.close();

    // Release our handle first: removeDatabase() complains while the
    // connection is still referenced by a QSqlDatabase object.
    m_db = QSqlDatabase();
    // The name is empty when no connection was ever opened by this object.
    if (!connectionName.isEmpty())
        QSqlDatabase::removeDatabase(connectionName);
}
// Opens the database and makes sure the schema exists; false when either step fails.
bool RAGStorage::init()
{
    // Short-circuit keeps the original order: tables are only created on an open database.
    return openDatabase() && createTables();
}
bool RAGStorage::openDatabase()
{
QDir dir(QFileInfo(m_dbPath).absolutePath());
if (!dir.exists()) {
dir.mkpath(".");
}
m_db = QSqlDatabase::addDatabase("QSQLITE", "rag_storage");
m_db.setDatabaseName(m_dbPath);
return m_db.open();
}
bool RAGStorage::createTables()
{
QSqlQuery query(m_db);
return query.exec("CREATE TABLE IF NOT EXISTS file_vectors ("
"id INTEGER PRIMARY KEY AUTOINCREMENT,"
"file_path TEXT UNIQUE NOT NULL,"
"vector_data BLOB NOT NULL,"
"last_modified DATETIME NOT NULL,"
"created_at DATETIME DEFAULT CURRENT_TIMESTAMP,"
"updated_at DATETIME DEFAULT CURRENT_TIMESTAMP"
")");
}
/**
 * Stores the vector of a file together with the file's current modification
 * time.
 *
 * file_path is UNIQUE, so a plain INSERT fails for a file that already has a
 * row - exactly the case of an outdated file being vectorized again. The
 * statement therefore updates the existing row on conflict (SQLite upsert).
 *
 * @param filePath Path used as the row key.
 * @param vector   Vector to store.
 * @return true when the statement was prepared and executed successfully.
 */
bool RAGStorage::storeVector(const QString &filePath, const RAGVector &vector)
{
    QSqlQuery query(m_db);
    const bool prepared = query.prepare(
        "INSERT INTO file_vectors (file_path, vector_data, last_modified) "
        "VALUES (:path, :vector, :modified) "
        "ON CONFLICT(file_path) DO UPDATE SET "
        "vector_data = excluded.vector_data, "
        "last_modified = excluded.last_modified, "
        "updated_at = CURRENT_TIMESTAMP");
    if (!prepared)
        return false;

    query.bindValue(":path", filePath);
    query.bindValue(":vector", vectorToBlob(vector));
    query.bindValue(":modified", getFileLastModified(filePath));
    return query.exec();
}
bool RAGStorage::updateVector(const QString &filePath, const RAGVector &vector)
{
QSqlQuery query(m_db);
query.prepare("UPDATE file_vectors "
"SET vector_data = :vector, last_modified = :modified, "
"updated_at = CURRENT_TIMESTAMP "
"WHERE file_path = :path");
query.bindValue(":vector", vectorToBlob(vector));
query.bindValue(":modified", getFileLastModified(filePath));
query.bindValue(":path", filePath);
return query.exec();
}
std::optional<RAGVector> RAGStorage::getVector(const QString &filePath)
{
QSqlQuery query(m_db);
query.prepare("SELECT vector_data FROM file_vectors WHERE file_path = :path");
query.bindValue(":path", filePath);
if (query.exec() && query.next()) {
return blobToVector(query.value(0).toByteArray());
}
return std::nullopt;
}
bool RAGStorage::needsUpdate(const QString &filePath)
{
QSqlQuery query(m_db);
query.prepare("SELECT last_modified FROM file_vectors WHERE file_path = :path");
query.bindValue(":path", filePath);
if (query.exec() && query.next()) {
QDateTime storedTime = query.value(0).toDateTime();
return storedTime < getFileLastModified(filePath);
}
return true;
}
// Returns the file_path of every row in file_vectors; empty when the query fails.
QStringList RAGStorage::getAllFiles()
{
    QStringList paths;
    QSqlQuery listing(m_db);
    if (!listing.exec("SELECT file_path FROM file_vectors"))
        return paths;

    while (listing.next())
        paths.append(listing.value(0).toString());
    return paths;
}
// Returns the last-modified timestamp QFileInfo reports for the file.
QDateTime RAGStorage::getFileLastModified(const QString &filePath)
{
    const QFileInfo info(filePath);
    return info.lastModified();
}
/**
 * Reinterprets a blob as consecutive floats and copies them into a vector.
 * Trailing bytes that do not fill a whole float are ignored.
 */
RAGVector RAGStorage::blobToVector(const QByteArray &blob)
{
    const size_t floatCount = blob.size() / sizeof(float);
    const auto *first = reinterpret_cast<const float *>(blob.constData());
    return RAGVector(first, first + floatCount);
}
// Copies the vector's floats, in their in-memory representation, into a byte array.
QByteArray RAGStorage::vectorToBlob(const RAGVector &vector)
{
    const auto *rawBytes = reinterpret_cast<const char *>(vector.data());
    return QByteArray(rawBytes, vector.size() * sizeof(float));
}
// Returns the database file path this storage was constructed with.
QString RAGStorage::dbPath() const
{
return m_dbPath;
}
} // namespace QodeAssist::Context

58
context/RAGStorage.hpp Normal file
View File

@ -0,0 +1,58 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <QDateTime>
#include <QObject>
#include <QSqlDatabase>
#include <QString>
#include <RAGData.hpp>
namespace QodeAssist::Context {
// SQLite-backed store that maps file paths to vectors and the files'
// modification times (table file_vectors).
class RAGStorage : public QObject
{
Q_OBJECT
public:
// dbPath is the SQLite file to use; nothing is opened until init() is called.
explicit RAGStorage(const QString &dbPath, QObject *parent = nullptr);
~RAGStorage();
// Opens the database and creates the table if needed; false on failure.
bool init();
// Stores the vector of filePath with the file's current modification time; returns whether the query succeeded.
bool storeVector(const QString &filePath, const RAGVector &vector);
// Overwrites the vector and modification time of the existing row of filePath; returns whether the query succeeded.
bool updateVector(const QString &filePath, const RAGVector &vector);
// Returns the stored vector of filePath, or std::nullopt when there is none.
std::optional<RAGVector> getVector(const QString &filePath);
// True when filePath has no row or its stored modification time is older than the file's.
bool needsUpdate(const QString &filePath);
// Returns all stored file paths.
QStringList getAllFiles();
QString dbPath() const;
private:
bool createTables();
bool openDatabase();
QDateTime getFileLastModified(const QString &filePath);
// Conversion between a vector and the raw bytes of its floats.
RAGVector blobToVector(const QByteArray &blob);
QByteArray vectorToBlob(const RAGVector &vector);
QSqlDatabase m_db;
QString m_dbPath;
};
} // namespace QodeAssist::Context

82
context/RAGVectorizer.cpp Normal file
View File

@ -0,0 +1,82 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#include "RAGVectorizer.hpp"
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QNetworkReply>
namespace QodeAssist::Context {
RAGVectorizer::RAGVectorizer(const QString &providerUrl,
const QString &modelName,
QObject *parent)
: QObject(parent)
, m_network(new QNetworkAccessManager(this))
, m_embedProviderUrl(providerUrl)
, m_model(modelName)
{}
// The network manager is a QObject child of this object, so nothing is deleted by hand.
RAGVectorizer::~RAGVectorizer() = default;
// Builds the JSON request body: {"model": <configured model>, "prompt": <text>}.
QJsonObject RAGVectorizer::prepareEmbeddingRequest(const QString &text) const
{
    QJsonObject payload;
    payload.insert(QStringLiteral("model"), m_model);
    payload.insert(QStringLiteral("prompt"), text);
    return payload;
}
/**
 * Extracts the "embedding" array of a JSON response as a vector of floats.
 *
 * @return One float per array element; empty when the response has no
 *         "embedding" array.
 */
RAGVector RAGVectorizer::parseEmbeddingResponse(const QByteArray &response) const
{
    QJsonDocument document = QJsonDocument::fromJson(response);
    QJsonArray embedding = document.object()["embedding"].toArray();

    RAGVector components;
    components.reserve(embedding.size());
    for (const auto &entry : embedding) {
        // JSON numbers are doubles; the vector stores single precision.
        components.push_back(static_cast<float>(entry.toDouble()));
    }
    return components;
}
/**
 * Posts the text to <provider url>/api/embeddings and delivers the parsed
 * vector through a future.
 *
 * @param text Text to vectorize.
 * @return Future holding the parsed vector, or an empty vector when the
 *         network reply reports an error.
 */
QFuture<RAGVector> RAGVectorizer::vectorizeText(const QString &text)
{
    auto promise = std::make_shared<QPromise<RAGVector>>();
    promise->start();

    const QUrl endpoint(m_embedProviderUrl + "/api/embeddings");
    QNetworkRequest request(endpoint);
    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");

    const QByteArray body = QJsonDocument(prepareEmbeddingRequest(text)).toJson();
    auto *reply = m_network->post(request, body);

    connect(reply, &QNetworkReply::finished, this, [this, reply, promise]() {
        const bool succeeded = reply->error() == QNetworkReply::NoError;
        // TODO check error setException
        promise->addResult(succeeded ? parseEmbeddingResponse(reply->readAll()) : RAGVector());
        promise->finish();
        reply->deleteLater();
    });

    return promise->future();
}
} // namespace QodeAssist::Context

50
context/RAGVectorizer.hpp Normal file
View File

@ -0,0 +1,50 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <QFuture>
#include <QNetworkAccessManager>
#include <QObject>
#include <RAGData.hpp>
namespace QodeAssist::Context {
// Requests vectors for text from an HTTP endpoint (<providerUrl>/api/embeddings).
class RAGVectorizer : public QObject
{
Q_OBJECT
public:
// providerUrl is the base URL of the service, modelName the value sent as "model" in each request.
explicit RAGVectorizer(const QString &providerUrl = "http://localhost:11434",
const QString &modelName = "all-minilm",
QObject *parent = nullptr);
~RAGVectorizer();
// Sends the text to the provider; the future holds the returned vector (empty on a network error).
QFuture<RAGVector> vectorizeText(const QString &text);
private:
// Builds the JSON body of a request for the given text.
QJsonObject prepareEmbeddingRequest(const QString &text) const;
// Reads the "embedding" array of a response into a vector.
RAGVector parseEmbeddingResponse(const QByteArray &response) const;
QNetworkAccessManager *m_network;
QString m_embedProviderUrl;
QString m_model;
};
} // namespace QodeAssist::Context