Add basic chat widgets and functionality

This commit is contained in:
Petr Mironychev
2024-09-15 01:58:56 +02:00
parent 6e56646b4c
commit 04c44f5916
37 changed files with 1422 additions and 369 deletions

View File

@ -29,16 +29,18 @@
#include "LLMProvidersManager.hpp"
#include "PromptTemplateManager.hpp"
#include "QodeAssistUtils.hpp"
#include "core/ChangesManager.h"
#include "settings/ContextSettings.hpp"
#include "core/LLMRequestConfig.hpp"
#include "settings/GeneralSettings.hpp"
namespace QodeAssist {
LLMClientInterface::LLMClientInterface()
: m_manager(new QNetworkAccessManager(this))
: m_requestHandler(this)
{
updateProvider();
connect(&m_requestHandler,
&LLMRequestHandler::completionReceived,
this,
&LLMClientInterface::sendCompletionToClient);
}
Utils::FilePath LLMClientInterface::serverDeviceTemplate() const
@ -53,8 +55,6 @@ void LLMClientInterface::startImpl()
void LLMClientInterface::sendData(const QByteArray &data)
{
updateProvider();
QJsonDocument doc = QJsonDocument::fromJson(data);
if (!doc.isObject())
return;
@ -86,87 +86,13 @@ void LLMClientInterface::sendData(const QByteArray &data)
void LLMClientInterface::handleCancelRequest(const QJsonObject &request)
{
QString id = request["params"].toObject()["id"].toString();
if (m_activeRequests.contains(id)) {
m_activeRequests[id]->abort();
m_activeRequests.remove(id);
if (m_requestHandler.cancelRequest(id)) {
logMessage(QString("Request %1 cancelled successfully").arg(id));
} else {
logMessage(QString("Request %1 not found").arg(id));
}
}
bool LLMClientInterface::processSingleLineCompletion(QNetworkReply *reply,
const QJsonObject &request,
const QString &accumulatedCompletion)
{
int newlinePos = accumulatedCompletion.indexOf('\n');
if (newlinePos != -1) {
QString singleLineCompletion = accumulatedCompletion.left(newlinePos).trimmed();
singleLineCompletion = removeStopWords(singleLineCompletion);
QJsonObject position = request["params"].toObject()["doc"].toObject()["position"].toObject();
sendCompletionToClient(singleLineCompletion, request, position, true);
m_accumulatedResponses.remove(reply);
reply->abort();
return true;
}
return false;
}
QString LLMClientInterface::сontextBefore(TextEditor::TextEditorWidget *widget,
int lineNumber,
int cursorPosition)
{
if (!widget)
return QString();
DocumentContextReader reader(widget->textDocument());
const auto &copyright = reader.copyrightInfo();
logMessage(QString{"Line Number: %1"}.arg(lineNumber));
logMessage(QString("Copyright found %1 %2").arg(copyright.found).arg(copyright.endLine));
if (lineNumber < reader.findCopyright().endLine)
return QString();
QString contextBefore;
if (Settings::contextSettings().readFullFile()) {
contextBefore = reader.readWholeFileBefore(lineNumber, cursorPosition);
} else {
contextBefore
= reader.getContextBefore(lineNumber,
cursorPosition,
Settings::contextSettings().readStringsBeforeCursor());
}
return contextBefore;
}
QString LLMClientInterface::сontextAfter(TextEditor::TextEditorWidget *widget,
int lineNumber,
int cursorPosition)
{
if (!widget)
return QString();
DocumentContextReader reader(widget->textDocument());
if (lineNumber < reader.findCopyright().endLine)
return QString();
QString contextAfter;
if (Settings::contextSettings().readFullFile()) {
contextAfter = reader.readWholeFileAfter(lineNumber, cursorPosition);
} else {
contextAfter = reader.getContextAfter(lineNumber,
cursorPosition,
Settings::contextSettings().readStringsAfterCursor());
}
return contextAfter;
}
void LLMClientInterface::handleInitialize(const QJsonObject &request)
{
QJsonObject response;
@ -217,40 +143,26 @@ void LLMClientInterface::handleExit(const QJsonObject &request)
emit finished();
}
void LLMClientInterface::handleLLMResponse(QNetworkReply *reply, const QJsonObject &request)
void LLMClientInterface::handleCompletion(const QJsonObject &request)
{
QString &accumulatedResponse = m_accumulatedResponses[reply];
auto updatedContext = prepareContext(request);
auto &templateManager = PromptTemplateManager::instance();
const Templates::PromptTemplate *currentTemplate = templateManager.getCurrentTemplate();
LLMConfig config;
config.requestType = RequestType::Fim;
config.provider = LLMProvidersManager::instance().getCurrentFimProvider();
config.promptTemplate = PromptTemplateManager::instance().getCurrentFimTemplate();
config.url = QUrl(QString("%1%2").arg(Settings::generalSettings().url(),
Settings::generalSettings().endPoint()));
auto &providerManager = LLMProvidersManager::instance();
bool isComplete = providerManager.getCurrentProvider()->handleResponse(reply,
accumulatedResponse);
config.providerRequest = {{"model", Settings::generalSettings().modelName.value()},
{"stream", true},
{"stop",
QJsonArray::fromStringList(config.promptTemplate->stopWords())}};
QJsonObject position = request["params"].toObject()["doc"].toObject()["position"].toObject();
config.promptTemplate->prepareRequest(config.providerRequest, updatedContext);
config.provider->prepareRequest(config.providerRequest);
if (!Settings::generalSettings().multiLineCompletion()
&& processSingleLineCompletion(reply, request, accumulatedResponse)) {
return;
}
if (isComplete || reply->isFinished()) {
if (isComplete) {
auto cleanedCompletion = removeStopWords(accumulatedResponse);
sendCompletionToClient(cleanedCompletion, request, position, true);
} else {
handleCompletion(request, accumulatedResponse);
}
m_accumulatedResponses.remove(reply);
}
}
void LLMClientInterface::handleCompletion(const QJsonObject &request,
const QStringView &accumulatedCompletion)
{
auto updatedContext = prepareContext(request, accumulatedCompletion);
sendLLMRequest(request, updatedContext);
m_requestHandler.sendLLMRequest(config, request);
}
ContextData LLMClientInterface::prepareContext(const QJsonObject &request,
@ -273,39 +185,16 @@ ContextData LLMClientInterface::prepareContext(const QJsonObject &request,
int cursorPosition = position["character"].toInt();
int lineNumber = position["line"].toInt();
auto textEditor = TextEditor::BaseTextEditor::currentTextEditor();
TextEditor::TextEditorWidget *widget = textEditor->editorWidget();
DocumentContextReader reader(widget->textDocument());
QString recentChanges = ChangesManager::instance().getRecentChangesContext(textDocument);
QString contextBefore = сontextBefore(widget, lineNumber, cursorPosition);
QString contextAfter = сontextAfter(widget, lineNumber, cursorPosition);
QString instructions
= QString("%1%2%3").arg(Settings::contextSettings().useSpecificInstructions()
? reader.getSpecificInstructions()
: QString(),
Settings::contextSettings().useFilePathInContext()
? reader.getLanguageAndFileInfo()
: QString(),
Settings::contextSettings().useProjectChangesCache() ? recentChanges
: QString());
return {QString("%1%2").arg(contextBefore, accumulatedCompletion), contextAfter, instructions};
}
void LLMClientInterface::updateProvider()
{
m_serverUrl = QUrl(QString("%1%2").arg(Settings::generalSettings().url(),
Settings::generalSettings().endPoint()));
DocumentContextReader reader(textDocument);
return reader.prepareContext(lineNumber, cursorPosition);
}
void LLMClientInterface::sendCompletionToClient(const QString &completion,
const QJsonObject &request,
const QJsonObject &position,
bool isComplete)
{
QJsonObject position = request["params"].toObject()["doc"].toObject()["position"].toObject();
QJsonObject response;
response["jsonrpc"] = "2.0";
response[LanguageServerProtocol::idKey] = request["id"];
@ -337,69 +226,6 @@ void LLMClientInterface::sendCompletionToClient(const QString &completion,
emit messageReceived(LanguageServerProtocol::JsonRpcMessage(response));
}
void LLMClientInterface::sendLLMRequest(const QJsonObject &request, const ContextData &prompt)
{
QJsonObject providerRequest = {{"model", Settings::generalSettings().modelName.value()},
{"stream", true}};
auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate();
currentTemplate->prepareRequest(providerRequest, prompt);
auto &providerManager = LLMProvidersManager::instance();
providerManager.getCurrentProvider()->prepareRequest(providerRequest);
logMessage(QString("Sending request to llm: \nurl: %1\nRequest body:\n%2")
.arg(m_serverUrl.toString(),
QString::fromUtf8(
QJsonDocument(providerRequest).toJson(QJsonDocument::Indented))));
QNetworkRequest networkRequest(m_serverUrl);
networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
if (providerRequest.contains("api_key")) {
QString apiKey = providerRequest["api_key"].toString();
networkRequest.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey).toUtf8());
providerRequest.remove("api_key");
}
QNetworkReply *reply = m_manager->post(networkRequest, QJsonDocument(providerRequest).toJson());
if (!reply) {
logMessage("Error: Failed to create network reply");
return;
}
QString requestId = request["id"].toString();
m_activeRequests[requestId] = reply;
connect(reply, &QNetworkReply::readyRead, this, [this, reply, request]() {
handleLLMResponse(reply, request);
});
connect(reply, &QNetworkReply::finished, this, [this, reply, requestId]() {
reply->deleteLater();
m_activeRequests.remove(requestId);
if (reply->error() != QNetworkReply::NoError) {
logMessage(QString("Error in QodeAssist request: %1").arg(reply->errorString()));
} else {
logMessage("Request finished successfully");
}
});
}
QString LLMClientInterface::removeStopWords(const QStringView &completion)
{
QString filteredCompletion = completion.toString();
auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate();
QStringList stopWords = currentTemplate->stopWords();
for (const QString &stopWord : stopWords) {
filteredCompletion = filteredCompletion.replace(stopWord, "");
}
return filteredCompletion;
}
void LLMClientInterface::startTimeMeasurement(const QString &requestId)
{
m_requestStartTimes[requestId] = QDateTime::currentMSecsSinceEpoch();