feat: add enhancedSearch

This commit is contained in:
Petr Mironychev 2025-01-27 22:02:50 +01:00
parent 09c38c8b0e
commit 77a03d42ed
8 changed files with 591 additions and 26 deletions

View File

@ -10,6 +10,8 @@ add_library(Context STATIC
RAGData.hpp
RAGVectorizer.hpp RAGVectorizer.cpp
RAGSimilaritySearch.hpp RAGSimilaritySearch.cpp
RAGPreprocessor.hpp RAGPreprocessor.cpp
EnhancedRAGSimilaritySearch.hpp EnhancedRAGSimilaritySearch.cpp
)
target_link_libraries(Context

View File

@ -0,0 +1,265 @@
#include "EnhancedRAGSimilaritySearch.hpp"
#include <QSet>
namespace QodeAssist::Context {
// Static regex getters
const QRegularExpression &EnhancedRAGSimilaritySearch::getNamespaceRegex()
{
static const QRegularExpression regex(R"(namespace\s+(?:\w+\s*::\s*)*\w+\s*\{)");
return regex;
}
const QRegularExpression &EnhancedRAGSimilaritySearch::getClassRegex()
{
static const QRegularExpression regex(
R"((?:template\s*<[^>]*>\s*)?(?:class|struct)\s+(\w+)\s*(?:final\s*)?(?::\s*(?:public|protected|private)\s+\w+(?:\s*,\s*(?:public|protected|private)\s+\w+)*\s*)?{)");
return regex;
}
const QRegularExpression &EnhancedRAGSimilaritySearch::getFunctionRegex()
{
static const QRegularExpression regex(
R"((?:virtual\s+)?(?:static\s+)?(?:inline\s+)?(?:explicit\s+)?(?:constexpr\s+)?(?:[\w:]+\s+)?(?:\w+\s*::\s*)*\w+\s*\([^)]*\)\s*(?:const\s*)?(?:noexcept\s*)?(?:override\s*)?(?:final\s*)?(?:=\s*0\s*)?(?:=\s*default\s*)?(?:=\s*delete\s*)?(?:\s*->.*?)?\s*{)");
return regex;
}
const QRegularExpression &EnhancedRAGSimilaritySearch::getTemplateRegex()
{
static const QRegularExpression regex(R"(template\s*<[^>]*>\s*(?:class|struct|typename)\s+\w+)");
return regex;
}
// Cache getters
QCache<QString, EnhancedRAGSimilaritySearch::SimilarityScore> &
EnhancedRAGSimilaritySearch::getScoreCache()
{
static QCache<QString, SimilarityScore> cache(1000); // Cache size of 1000 entries
return cache;
}
QCache<QString, QStringList> &EnhancedRAGSimilaritySearch::getStructureCache()
{
static QCache<QString, QStringList> cache(500); // Cache size of 500 entries
return cache;
}
// Main public interface
EnhancedRAGSimilaritySearch::SimilarityScore EnhancedRAGSimilaritySearch::calculateSimilarity(
const RAGVector &v1, const RAGVector &v2, const QString &code1, const QString &code2)
{
// Generate cache key based on content hashes
QString cacheKey = QString("%1_%2").arg(qHash(code1)).arg(qHash(code2));
// Check cache first
auto &scoreCache = getScoreCache();
if (auto *cached = scoreCache.object(cacheKey)) {
return *cached;
}
// Calculate new similarity score
SimilarityScore score = calculateSimilarityInternal(v1, v2, code1, code2);
// Cache the result
scoreCache.insert(cacheKey, new SimilarityScore(score));
return score;
}
// Internal implementation
EnhancedRAGSimilaritySearch::SimilarityScore EnhancedRAGSimilaritySearch::calculateSimilarityInternal(
const RAGVector &v1, const RAGVector &v2, const QString &code1, const QString &code2)
{
if (v1.empty() || v2.empty()) {
LOG_MESSAGE("Warning: Empty vectors in similarity calculation");
return SimilarityScore(0.0f, 0.0f, 0.0f);
}
if (v1.size() != v2.size()) {
LOG_MESSAGE(QString("Vector size mismatch: %1 vs %2").arg(v1.size()).arg(v2.size()));
return SimilarityScore(0.0f, 0.0f, 0.0f);
}
// Calculate semantic similarity using vector embeddings
float semantic_similarity = 0.0f;
#if defined(__SSE__) || defined(_M_X64) || defined(_M_AMD64)
if (v1.size() >= 4) { // Use SSE for vectors of 4 or more elements
semantic_similarity = calculateCosineSimilaritySSE(v1, v2);
} else {
semantic_similarity = calculateCosineSimilarity(v1, v2);
}
#else
semantic_similarity = calculateCosineSimilarity(v1, v2);
#endif
// If semantic similarity is very low, skip structural analysis
if (semantic_similarity < 0.0001f) {
return SimilarityScore(0.0f, 0.0f, 0.0f);
}
// Calculate structural similarity
float structural_similarity = calculateStructuralSimilarity(code1, code2);
// Calculate combined score with dynamic weights
float semantic_weight = 0.7f;
const int large_file_threshold = 10000;
if (code1.size() > large_file_threshold || code2.size() > large_file_threshold) {
semantic_weight = 0.8f; // Increase semantic weight for large files
}
float combined_score = semantic_weight * semantic_similarity
+ (1.0f - semantic_weight) * structural_similarity;
return SimilarityScore(semantic_similarity, structural_similarity, combined_score);
}
float EnhancedRAGSimilaritySearch::calculateCosineSimilarity(const RAGVector &v1, const RAGVector &v2)
{
float dotProduct = 0.0f;
float norm1 = 0.0f;
float norm2 = 0.0f;
for (size_t i = 0; i < v1.size(); ++i) {
dotProduct += v1[i] * v2[i];
norm1 += v1[i] * v1[i];
norm2 += v2[i] * v2[i];
}
norm1 = std::sqrt(norm1);
norm2 = std::sqrt(norm2);
if (norm1 == 0.0f || norm2 == 0.0f) {
return 0.0f;
}
return dotProduct / (norm1 * norm2);
}
#if defined(__SSE__) || defined(_M_X64) || defined(_M_AMD64)
float EnhancedRAGSimilaritySearch::calculateCosineSimilaritySSE(
const RAGVector &v1, const RAGVector &v2)
{
const float *p1 = v1.data();
const float *p2 = v2.data();
const size_t size = v1.size();
const size_t alignedSize = size & ~3ULL; // Round down to multiple of 4
__m128 sum = _mm_setzero_ps();
__m128 norm1 = _mm_setzero_ps();
__m128 norm2 = _mm_setzero_ps();
// Process 4 elements at a time using SSE
for (size_t i = 0; i < alignedSize; i += 4) {
__m128 v1_vec = _mm_loadu_ps(p1 + i); // Use unaligned load for safety
__m128 v2_vec = _mm_loadu_ps(p2 + i);
sum = _mm_add_ps(sum, _mm_mul_ps(v1_vec, v2_vec));
norm1 = _mm_add_ps(norm1, _mm_mul_ps(v1_vec, v1_vec));
norm2 = _mm_add_ps(norm2, _mm_mul_ps(v2_vec, v2_vec));
}
float dotProduct = horizontalSum(sum);
float n1 = std::sqrt(horizontalSum(norm1));
float n2 = std::sqrt(horizontalSum(norm2));
// Process remaining elements
for (size_t i = alignedSize; i < size; ++i) {
dotProduct += v1[i] * v2[i];
n1 += v1[i] * v1[i];
n2 += v2[i] * v2[i];
}
if (n1 == 0.0f || n2 == 0.0f) {
return 0.0f;
}
return dotProduct / (std::sqrt(n1) * std::sqrt(n2));
}
float EnhancedRAGSimilaritySearch::horizontalSum(__m128 x)
{
__m128 shuf = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
__m128 sums = _mm_add_ps(x, shuf);
shuf = _mm_movehl_ps(shuf, sums);
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
}
#endif
float EnhancedRAGSimilaritySearch::calculateStructuralSimilarity(
const QString &code1, const QString &code2)
{
QStringList structures1 = extractStructures(code1);
QStringList structures2 = extractStructures(code2);
return calculateJaccardSimilarity(structures1, structures2);
}
QStringList EnhancedRAGSimilaritySearch::extractStructures(const QString &code)
{
// Check cache first
auto &structureCache = getStructureCache();
QString cacheKey = QString::number(qHash(code));
if (auto *cached = structureCache.object(cacheKey)) {
return *cached;
}
QStringList structures;
structures.reserve(100); // Reserve space for typical file
// Extract namespaces
auto namespaceMatches = getNamespaceRegex().globalMatch(code);
while (namespaceMatches.hasNext()) {
structures.append(namespaceMatches.next().captured().trimmed());
}
// Extract classes
auto classMatches = getClassRegex().globalMatch(code);
while (classMatches.hasNext()) {
structures.append(classMatches.next().captured().trimmed());
}
// Extract functions
auto functionMatches = getFunctionRegex().globalMatch(code);
while (functionMatches.hasNext()) {
structures.append(functionMatches.next().captured().trimmed());
}
// Extract templates
auto templateMatches = getTemplateRegex().globalMatch(code);
while (templateMatches.hasNext()) {
structures.append(templateMatches.next().captured().trimmed());
}
// Cache the result
structureCache.insert(cacheKey, new QStringList(structures));
return structures;
}
float EnhancedRAGSimilaritySearch::calculateJaccardSimilarity(
const QStringList &set1, const QStringList &set2)
{
if (set1.isEmpty() && set2.isEmpty()) {
return 1.0f; // Пустые множества считаем идентичными
}
if (set1.isEmpty() || set2.isEmpty()) {
return 0.0f;
}
QSet<QString> set1Unique = QSet<QString>(set1.begin(), set1.end());
QSet<QString> set2Unique = QSet<QString>(set2.begin(), set2.end());
QSet<QString> intersection = set1Unique;
intersection.intersect(set2Unique);
QSet<QString> union_set = set1Unique;
union_set.unite(set2Unique);
return static_cast<float>(intersection.size()) / union_set.size();
}
} // namespace QodeAssist::Context

View File

@ -0,0 +1,74 @@
#pragma once
#include <QCache>
#include <QHash>
#include <QRegularExpression>
#include <QString>
#include <QStringList>
#include <QtGlobal>
#include <algorithm>
#include <cmath>
#include <optional>
#include <vector>
#if defined(__SSE__) || defined(_M_X64) || defined(_M_AMD64)
#include <emmintrin.h>
#include <xmmintrin.h>
#endif
#include "RAGData.hpp"
#include "logger/Logger.hpp"
namespace QodeAssist::Context {
class EnhancedRAGSimilaritySearch
{
public:
struct SimilarityScore
{
float semantic_similarity{0.0f};
float structural_similarity{0.0f};
float combined_score{0.0f};
SimilarityScore() = default;
SimilarityScore(float sem, float str, float comb)
: semantic_similarity(sem)
, structural_similarity(str)
, combined_score(comb)
{}
};
static SimilarityScore calculateSimilarity(
const RAGVector &v1, const RAGVector &v2, const QString &code1, const QString &code2);
private:
static SimilarityScore calculateSimilarityInternal(
const RAGVector &v1, const RAGVector &v2, const QString &code1, const QString &code2);
static float calculateCosineSimilarity(const RAGVector &v1, const RAGVector &v2);
#if defined(__SSE__) || defined(_M_X64) || defined(_M_AMD64)
static float calculateCosineSimilaritySSE(const RAGVector &v1, const RAGVector &v2);
static float horizontalSum(__m128 x);
#endif
static float calculateStructuralSimilarity(const QString &code1, const QString &code2);
static QStringList extractStructures(const QString &code);
static float calculateJaccardSimilarity(const QStringList &set1, const QStringList &set2);
static const QRegularExpression &getNamespaceRegex();
static const QRegularExpression &getClassRegex();
static const QRegularExpression &getFunctionRegex();
static const QRegularExpression &getTemplateRegex();
// Cache for similarity scores
static QCache<QString, SimilarityScore> &getScoreCache();
// Cache for extracted structures
static QCache<QString, QStringList> &getStructureCache();
EnhancedRAGSimilaritySearch() = delete; // Prevent instantiation
};
} // namespace QodeAssist::Context

View File

@ -27,6 +27,9 @@
#include <QtConcurrent>
#include <queue>
#include <EnhancedRAGSimilaritySearch.hpp>
#include <RAGPreprocessor.hpp>
namespace QodeAssist::Context {
RAGManager &RAGManager::instance()
@ -42,12 +45,12 @@ RAGManager::RAGManager(QObject *parent)
RAGManager::~RAGManager() {}
bool RAGManager::SearchResult::operator<(const SearchResult &other) const
{
if (cosineScore != other.cosineScore)
return cosineScore > other.cosineScore;
return l2Score < other.l2Score;
}
// bool RAGManager::SearchResult::operator<(const SearchResult &other) const
// {
// if (cosineScore != other.cosineScore)
// return cosineScore > other.cosineScore;
// return l2Score < other.l2Score;
// }
QString RAGManager::getStoragePath(ProjectExplorer::Project *project) const
{
@ -55,6 +58,26 @@ QString RAGManager::getStoragePath(ProjectExplorer::Project *project) const
.arg(Core::ICore::userResourcePath().toString(), project->displayName());
}
std::optional<QString> RAGManager::loadFileContent(const QString &filePath)
{
QFile file(filePath);
if (!file.open(QIODevice::ReadOnly | QIODevice::Text)) {
qDebug() << "ERROR: Failed to open file for reading:" << filePath
<< "Error:" << file.errorString();
return std::nullopt;
}
QFileInfo fileInfo(filePath);
qDebug() << "Loading content from file:" << fileInfo.fileName() << "Size:" << fileInfo.size()
<< "bytes";
QString content = QString::fromUtf8(file.readAll());
if (content.isEmpty()) {
qDebug() << "WARNING: Empty content read from file:" << filePath;
}
return content;
}
void RAGManager::ensureStorageForProject(ProjectExplorer::Project *project)
{
if (m_currentProject == project && m_currentStorage) {
@ -110,6 +133,93 @@ QFuture<void> RAGManager::processFiles(
return promise->future();
}
void RAGManager::searchSimilarDocuments(
const QString &text, ProjectExplorer::Project *project, int topK)
{
qDebug() << "\nStarting similarity search with parameters:";
qDebug() << "Query length:" << text.length();
qDebug() << "Project:" << project->displayName();
qDebug() << "Top K:" << topK;
// Предобработка текста запроса
QString processedText = RAGPreprocessor::preprocessCode(text);
qDebug() << "Preprocessed query length:" << processedText.length();
auto future = m_vectorizer->vectorizeText(processedText);
qDebug() << "Started query vectorization";
future.then([this, project, processedText, topK, text](const RAGVector &queryVector) {
if (queryVector.empty()) {
qDebug() << "ERROR: Query vectorization failed - empty vector";
return;
}
qDebug() << "Query vector generated, size:" << queryVector.size();
auto storedFiles = getStoredFiles(project);
qDebug() << "Found" << storedFiles.size() << "stored files to compare";
QList<SearchResult> results;
results.reserve(storedFiles.size());
int processedFiles = 0;
int skippedFiles = 0;
for (const auto &filePath : storedFiles) {
// Загружаем и обрабатываем содержимое файла
auto storedCode = loadFileContent(filePath);
if (!storedCode.has_value()) {
qDebug() << "ERROR: Failed to load content for file:" << filePath;
skippedFiles++;
continue;
}
// Получаем вектор из хранилища
auto storedVector = loadVectorFromStorage(project, filePath);
if (!storedVector.has_value()) {
qDebug() << "ERROR: Failed to load vector for file:" << filePath;
skippedFiles++;
continue;
}
// Предобработка содержимого файла
QString processedStoredCode = RAGPreprocessor::preprocessCode(storedCode.value());
// Используем улучшенное сравнение
auto similarity = EnhancedRAGSimilaritySearch::calculateSimilarity(
queryVector, storedVector.value(), processedText, processedStoredCode);
results.append(
{filePath,
similarity.semantic_similarity,
similarity.structural_similarity,
similarity.combined_score});
processedFiles++;
if (processedFiles % 100 == 0) {
qDebug() << "Processed" << processedFiles << "files...";
}
}
qDebug() << "\nSearch statistics:";
qDebug() << "Total files processed:" << processedFiles;
qDebug() << "Files skipped:" << skippedFiles;
qDebug() << "Total results before filtering:" << results.size();
// Оптимизированная сортировка топ K результатов
if (results.size() > topK) {
qDebug() << "Performing partial sort for top" << topK << "results";
std::partial_sort(results.begin(), results.begin() + topK, results.end());
results = results.mid(0, topK);
} else {
qDebug() << "Performing full sort for" << results.size() << "results";
std::sort(results.begin(), results.end());
}
qDebug() << "Sorting completed, logging final results...";
logSearchResults(results);
});
}
void RAGManager::processNextBatch(
std::shared_ptr<QPromise<void>> promise,
ProjectExplorer::Project *project,
@ -158,11 +268,14 @@ void RAGManager::processNextBatch(
QFuture<bool> RAGManager::processFile(ProjectExplorer::Project *project, const QString &filePath)
{
qDebug() << "Starting to process file:" << filePath;
auto promise = std::make_shared<QPromise<bool>>();
promise->start();
ensureStorageForProject(project);
if (!m_currentStorage) {
qDebug() << "ERROR: Storage not initialized for project" << project->displayName();
promise->addResult(false);
promise->finish();
return promise->future();
@ -170,6 +283,7 @@ QFuture<bool> RAGManager::processFile(ProjectExplorer::Project *project, const Q
QFile file(filePath);
if (!file.open(QIODevice::ReadOnly)) {
qDebug() << "ERROR: Failed to open file for reading:" << filePath;
promise->addResult(false);
promise->finish();
return promise->future();
@ -177,14 +291,29 @@ QFuture<bool> RAGManager::processFile(ProjectExplorer::Project *project, const Q
QFileInfo fileInfo(filePath);
QString fileName = fileInfo.fileName();
QString content = QString("// %1\n%2").arg(fileName, QString::fromUtf8(file.readAll()));
QString content = QString::fromUtf8(file.readAll());
auto vectorFuture = m_vectorizer->vectorizeText(content);
vectorFuture.then([promise, filePath, this](const RAGVector &vector) {
qDebug() << "File" << fileName << "read, content size:" << content.size() << "bytes";
// Предобработка контента
QString processedContent = RAGPreprocessor::preprocessCode(content);
qDebug() << "Preprocessed content size:" << processedContent.size() << "bytes";
auto vectorFuture = m_vectorizer->vectorizeText(processedContent);
qDebug() << "Started vectorization for file:" << fileName;
vectorFuture.then([promise, filePath, fileName, this](const RAGVector &vector) {
if (vector.empty()) {
qDebug() << "ERROR: Vectorization failed for file:" << fileName << "- empty vector";
promise->addResult(false);
} else {
qDebug() << "Vector generated for file:" << fileName << "size:" << vector.size();
bool success = m_currentStorage->storeVector(filePath, vector);
if (!success) {
qDebug() << "ERROR: Failed to store vector for file:" << fileName;
} else {
qDebug() << "Successfully stored vector for file:" << fileName;
}
promise->addResult(success);
}
promise->finish();
@ -273,22 +402,35 @@ QFuture<QList<RAGManager::SearchResult>> RAGManager::search(
return promise->future();
}
void RAGManager::searchSimilarDocuments(
const QString &text, ProjectExplorer::Project *project, int topK)
{
auto future = search(text, project, topK);
future.then([this](const QList<SearchResult> &results) { logSearchResults(results); });
}
// void RAGManager::searchSimilarDocuments(
// const QString &text, ProjectExplorer::Project *project, int topK)
// {
// auto future = search(text, project, topK);
// future.then([this](const QList<SearchResult> &results) { logSearchResults(results); });
// }
void RAGManager::logSearchResults(const QList<SearchResult> &results) const
{
qDebug() << QString("\nTop %1 similar documents:").arg(results.size());
qDebug() << "\n=== Search Results ===";
qDebug() << "Number of results:" << results.size();
for (const auto &result : results) {
qDebug() << QString("File: %1").arg(result.filePath);
qDebug() << QString(" Cosine Similarity: %1").arg(result.cosineScore);
qDebug() << QString(" L2 Distance: %1\n").arg(result.l2Score);
if (results.empty()) {
qDebug() << "No similar documents found.";
return;
}
for (int i = 0; i < results.size(); ++i) {
const auto &result = results[i];
QFileInfo fileInfo(result.filePath);
qDebug() << "\nResult #" << (i + 1);
qDebug() << "File:" << fileInfo.fileName();
qDebug() << "Full path:" << result.filePath;
qDebug() << "Semantic similarity:" << QString::number(result.semantic_similarity, 'f', 4);
qDebug() << "Structural similarity:"
<< QString::number(result.structural_similarity, 'f', 4);
qDebug() << "Combined score:" << QString::number(result.combined_score, 'f', 4);
}
qDebug() << "\n=== End of Results ===\n";
}
} // namespace QodeAssist::Context

View File

@ -40,13 +40,25 @@ class RAGManager : public QObject
public:
static RAGManager &instance();
// struct SearchResult
// {
// QString filePath;
// float l2Score;
// float cosineScore;
// bool operator<(const SearchResult &other) const;
// };
struct SearchResult
{
QString filePath;
float l2Score;
float cosineScore;
float semantic_similarity;
float structural_similarity;
float combined_score;
bool operator<(const SearchResult &other) const;
bool operator<(const SearchResult &other) const
{
return combined_score > other.combined_score;
}
};
// Process and vectorize files
@ -86,6 +98,7 @@ private:
std::unique_ptr<RAGVectorizer> m_vectorizer;
std::unique_ptr<RAGStorage> m_currentStorage;
ProjectExplorer::Project *m_currentProject{nullptr};
std::optional<QString> loadFileContent(const QString &filePath);
};
} // namespace QodeAssist::Context

View File

@ -0,0 +1,2 @@
#include "RAGPreprocessor.hpp"

View File

@ -0,0 +1,66 @@
#include <QRegularExpression>
#include <QString>
#include "Logger.hpp"
namespace QodeAssist::Context {
class RAGPreprocessor
{
public:
static const QRegularExpression &getLicenseRegex()
{
static const QRegularExpression regex(
R"((/\*[^*]*\*+(?:[^/*][^*]*\*+)*/)|//[^\n]*(?:\n|$))",
QRegularExpression::MultilineOption);
return regex;
}
static const QRegularExpression &getClassRegex()
{
static const QRegularExpression regex(
R"((?:template\s*<[^>]*>\s*)?(?:class|struct)\s+(\w+)\s*(?:final\s*)?(?::\s*(?:public|protected|private)\s+\w+(?:\s*,\s*(?:public|protected|private)\s+\w+)*\s*)?{)");
return regex;
}
static QString preprocessCode(const QString &code)
{
if (code.isEmpty()) {
return QString();
}
try {
// Прямое разделение без промежуточной копии
QStringList lines = code.split('\n', Qt::SkipEmptyParts);
return processLines(lines);
} catch (const std::exception &e) {
LOG_MESSAGE(QString("Error preprocessing code: %1").arg(e.what()));
return code; // Возвращаем оригинальный код в случае ошибки
}
}
private:
static QString processLines(const QStringList &lines)
{
const int estimatedAvgLength = 80; // Примерная средняя длина строки
QString result;
result.reserve(lines.size() * estimatedAvgLength);
for (const QString &line : lines) {
const QString trimmed = line.trimmed();
if (!trimmed.isEmpty()) {
result += trimmed;
result += QLatin1Char('\n');
}
}
// Убираем последний перенос строки, если он есть
if (result.endsWith('\n')) {
result.chop(1);
}
return result;
}
};
} // namespace QodeAssist::Context

View File

@ -31,9 +31,10 @@ class RAGVectorizer : public QObject
{
Q_OBJECT
public:
explicit RAGVectorizer(const QString &providerUrl = "http://localhost:11434",
const QString &modelName = "all-minilm",
QObject *parent = nullptr);
explicit RAGVectorizer(
const QString &providerUrl = "http://localhost:11434",
const QString &modelName = "all-minilm:33m-l12-v2-fp16",
QObject *parent = nullptr);
~RAGVectorizer();
QFuture<RAGVector> vectorizeText(const QString &text);