From 04c44f5916779426503e72b9f0b27aa4e78ef8bd Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Sun, 15 Sep 2024 01:58:56 +0200 Subject: [PATCH] Add basic chat widgets and functionality --- CMakeLists.txt | 9 +- DocumentContextReader.cpp | 53 ++++ DocumentContextReader.hpp | 9 + LLMClientInterface.cpp | 226 ++-------------- LLMClientInterface.hpp | 26 +- LLMProvidersManager.cpp | 60 +++-- LLMProvidersManager.hpp | 14 +- PromptTemplateManager.cpp | 53 +++- PromptTemplateManager.hpp | 24 +- QodeAssistConstants.hpp | 6 + QodeAssistUtils.hpp | 3 + chat/ChatClientInterface.cpp | 153 +++++++++++ chat/ChatClientInterface.hpp | 53 ++++ chat/ChatOutputPane.cpp | 95 +++++++ chat/ChatOutputPane.h | 52 ++++ chat/ChatWidget.cpp | 150 +++++++++++ chat/ChatWidget.h | 62 +++++ core/LLMRequestConfig.hpp | 40 +++ core/LLMRequestHandler.cpp | 162 ++++++++++++ core/LLMRequestHandler.hpp | 64 +++++ providers/LLMProvider.hpp | 1 + providers/LMStudioProvider.cpp | 9 +- providers/LMStudioProvider.hpp | 1 + providers/OllamaProvider.cpp | 49 +++- providers/OllamaProvider.hpp | 1 + providers/OpenAICompatProvider.cpp | 10 +- providers/OpenAICompatProvider.hpp | 1 + qodeassist.cpp | 17 +- settings/GeneralSettings.cpp | 243 ++++++++++++------ settings/GeneralSettings.hpp | 25 +- ...aTemplate.hpp => CodeLlamaFimTemplate.hpp} | 5 +- templates/CodeLlamaInstruct.hpp | 49 ++++ templates/CustomTemplate.hpp | 3 +- templates/DeepSeekCoderChatTemplate.hpp | 54 ++++ templates/DeepSeekCoderV2.hpp | 3 +- templates/PromptTemplate.hpp | 3 + templates/StarCoder2Template.hpp | 3 +- 37 files changed, 1422 insertions(+), 369 deletions(-) create mode 100644 chat/ChatClientInterface.cpp create mode 100644 chat/ChatClientInterface.hpp create mode 100644 chat/ChatOutputPane.cpp create mode 100644 chat/ChatOutputPane.h create mode 100644 chat/ChatWidget.cpp create mode 100644 chat/ChatWidget.h create mode 100644 core/LLMRequestConfig.hpp create mode 100644 
core/LLMRequestHandler.cpp create mode 100644 core/LLMRequestHandler.hpp rename templates/{CodeLLamaTemplate.hpp => CodeLlamaFimTemplate.hpp} (88%) create mode 100644 templates/CodeLlamaInstruct.hpp create mode 100644 templates/DeepSeekCoderChatTemplate.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f1e9092..5616d1e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,10 +32,12 @@ add_qtc_plugin(QodeAssist LLMClientInterface.hpp LLMClientInterface.cpp PromptTemplateManager.hpp PromptTemplateManager.cpp templates/PromptTemplate.hpp - templates/CodeLLamaTemplate.hpp + templates/CodeLlamaFimTemplate.hpp templates/StarCoder2Template.hpp templates/DeepSeekCoderV2.hpp templates/CustomTemplate.hpp + templates/DeepSeekCoderChatTemplate.hpp + templates/CodeLlamaInstruct.hpp providers/LLMProvider.hpp providers/OllamaProvider.hpp providers/OllamaProvider.cpp providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp @@ -55,4 +57,9 @@ add_qtc_plugin(QodeAssist settings/PresetPromptsSettings.hpp settings/PresetPromptsSettings.cpp settings/SettingsUtils.hpp core/ChangesManager.h core/ChangesManager.cpp + core/LLMRequestHandler.hpp core/LLMRequestHandler.cpp + core/LLMRequestConfig.hpp + chat/ChatWidget.h chat/ChatWidget.cpp + chat/ChatOutputPane.h chat/ChatOutputPane.cpp + chat/ChatClientInterface.hpp chat/ChatClientInterface.cpp ) diff --git a/DocumentContextReader.cpp b/DocumentContextReader.cpp index 063162b..b3c2707 100644 --- a/DocumentContextReader.cpp +++ b/DocumentContextReader.cpp @@ -23,6 +23,7 @@ #include #include +#include "core/ChangesManager.h" #include "settings/ContextSettings.hpp" const QRegularExpression &getYearRegex() @@ -209,4 +210,56 @@ CopyrightInfo DocumentContextReader::copyrightInfo() const return m_copyrightInfo; } +ContextData DocumentContextReader::prepareContext(int lineNumber, int cursorPosition) const +{ + QString contextBefore = getContextBefore(lineNumber, cursorPosition); + QString contextAfter = getContextAfter(lineNumber, 
cursorPosition); + QString instructions = getInstructions(); + + return {contextBefore, contextAfter, instructions}; +} + +QString DocumentContextReader::getContextBefore(int lineNumber, int cursorPosition) const +{ + if (Settings::contextSettings().readFullFile()) { + return readWholeFileBefore(lineNumber, cursorPosition); + } else { + int effectiveStartLine; + int beforeCursor = Settings::contextSettings().readStringsBeforeCursor(); + if (m_copyrightInfo.found) { + effectiveStartLine = qMax(m_copyrightInfo.endLine + 1, lineNumber - beforeCursor); + } else { + effectiveStartLine = qMax(0, lineNumber - beforeCursor); + } + return getContextBetween(effectiveStartLine, lineNumber, cursorPosition); + } +} + +QString DocumentContextReader::getContextAfter(int lineNumber, int cursorPosition) const +{ + if (Settings::contextSettings().readFullFile()) { + return readWholeFileAfter(lineNumber, cursorPosition); + } else { + int endLine = qMin(m_document->blockCount() - 1, + lineNumber + Settings::contextSettings().readStringsAfterCursor()); + return getContextBetween(lineNumber + 1, endLine, -1); + } +} + +QString DocumentContextReader::getInstructions() const +{ + QString instructions; + + if (Settings::contextSettings().useSpecificInstructions()) + instructions += getSpecificInstructions(); + + if (Settings::contextSettings().useFilePathInContext()) + instructions += getLanguageAndFileInfo(); + + if (Settings::contextSettings().useProjectChangesCache()) + instructions += ChangesManager::instance().getRecentChangesContext(m_textDocument); + + return instructions; +} + } // namespace QodeAssist diff --git a/DocumentContextReader.hpp b/DocumentContextReader.hpp index 65ce53a..a3f9336 100644 --- a/DocumentContextReader.hpp +++ b/DocumentContextReader.hpp @@ -22,6 +22,8 @@ #include #include +#include "QodeAssistData.hpp" + namespace QodeAssist { struct CopyrightInfo @@ -48,6 +50,13 @@ public: CopyrightInfo copyrightInfo() const; + ContextData prepareContext(int lineNumber, int 
cursorPosition) const; + +private: + QString getContextBefore(int lineNumber, int cursorPosition) const; + QString getContextAfter(int lineNumber, int cursorPosition) const; + QString getInstructions() const; + private: TextEditor::TextDocument *m_textDocument; QTextDocument *m_document; diff --git a/LLMClientInterface.cpp b/LLMClientInterface.cpp index 029a70c..6dd25ae 100644 --- a/LLMClientInterface.cpp +++ b/LLMClientInterface.cpp @@ -29,16 +29,18 @@ #include "LLMProvidersManager.hpp" #include "PromptTemplateManager.hpp" #include "QodeAssistUtils.hpp" -#include "core/ChangesManager.h" -#include "settings/ContextSettings.hpp" +#include "core/LLMRequestConfig.hpp" #include "settings/GeneralSettings.hpp" namespace QodeAssist { LLMClientInterface::LLMClientInterface() - : m_manager(new QNetworkAccessManager(this)) + : m_requestHandler(this) { - updateProvider(); + connect(&m_requestHandler, + &LLMRequestHandler::completionReceived, + this, + &LLMClientInterface::sendCompletionToClient); } Utils::FilePath LLMClientInterface::serverDeviceTemplate() const @@ -53,8 +55,6 @@ void LLMClientInterface::startImpl() void LLMClientInterface::sendData(const QByteArray &data) { - updateProvider(); - QJsonDocument doc = QJsonDocument::fromJson(data); if (!doc.isObject()) return; @@ -86,87 +86,13 @@ void LLMClientInterface::sendData(const QByteArray &data) void LLMClientInterface::handleCancelRequest(const QJsonObject &request) { QString id = request["params"].toObject()["id"].toString(); - if (m_activeRequests.contains(id)) { - m_activeRequests[id]->abort(); - m_activeRequests.remove(id); + if (m_requestHandler.cancelRequest(id)) { logMessage(QString("Request %1 cancelled successfully").arg(id)); } else { logMessage(QString("Request %1 not found").arg(id)); } } -bool LLMClientInterface::processSingleLineCompletion(QNetworkReply *reply, - const QJsonObject &request, - const QString &accumulatedCompletion) -{ - int newlinePos = accumulatedCompletion.indexOf('\n'); - - if 
(newlinePos != -1) { - QString singleLineCompletion = accumulatedCompletion.left(newlinePos).trimmed(); - singleLineCompletion = removeStopWords(singleLineCompletion); - - QJsonObject position = request["params"].toObject()["doc"].toObject()["position"].toObject(); - - sendCompletionToClient(singleLineCompletion, request, position, true); - m_accumulatedResponses.remove(reply); - reply->abort(); - - return true; - } - return false; -} - -QString LLMClientInterface::сontextBefore(TextEditor::TextEditorWidget *widget, - int lineNumber, - int cursorPosition) -{ - if (!widget) - return QString(); - - DocumentContextReader reader(widget->textDocument()); - const auto ©right = reader.copyrightInfo(); - - logMessage(QString{"Line Number: %1"}.arg(lineNumber)); - logMessage(QString("Copyright found %1 %2").arg(copyright.found).arg(copyright.endLine)); - if (lineNumber < reader.findCopyright().endLine) - return QString(); - - QString contextBefore; - if (Settings::contextSettings().readFullFile()) { - contextBefore = reader.readWholeFileBefore(lineNumber, cursorPosition); - } else { - contextBefore - = reader.getContextBefore(lineNumber, - cursorPosition, - Settings::contextSettings().readStringsBeforeCursor()); - } - - return contextBefore; -} - -QString LLMClientInterface::сontextAfter(TextEditor::TextEditorWidget *widget, - int lineNumber, - int cursorPosition) -{ - if (!widget) - return QString(); - - DocumentContextReader reader(widget->textDocument()); - if (lineNumber < reader.findCopyright().endLine) - return QString(); - - QString contextAfter; - if (Settings::contextSettings().readFullFile()) { - contextAfter = reader.readWholeFileAfter(lineNumber, cursorPosition); - } else { - contextAfter = reader.getContextAfter(lineNumber, - cursorPosition, - Settings::contextSettings().readStringsAfterCursor()); - } - - return contextAfter; -} - void LLMClientInterface::handleInitialize(const QJsonObject &request) { QJsonObject response; @@ -217,40 +143,26 @@ void 
LLMClientInterface::handleExit(const QJsonObject &request) emit finished(); } -void LLMClientInterface::handleLLMResponse(QNetworkReply *reply, const QJsonObject &request) +void LLMClientInterface::handleCompletion(const QJsonObject &request) { - QString &accumulatedResponse = m_accumulatedResponses[reply]; + auto updatedContext = prepareContext(request); - auto &templateManager = PromptTemplateManager::instance(); - const Templates::PromptTemplate *currentTemplate = templateManager.getCurrentTemplate(); + LLMConfig config; + config.requestType = RequestType::Fim; + config.provider = LLMProvidersManager::instance().getCurrentFimProvider(); + config.promptTemplate = PromptTemplateManager::instance().getCurrentFimTemplate(); + config.url = QUrl(QString("%1%2").arg(Settings::generalSettings().url(), + Settings::generalSettings().endPoint())); - auto &providerManager = LLMProvidersManager::instance(); - bool isComplete = providerManager.getCurrentProvider()->handleResponse(reply, - accumulatedResponse); + config.providerRequest = {{"model", Settings::generalSettings().modelName.value()}, + {"stream", true}, + {"stop", + QJsonArray::fromStringList(config.promptTemplate->stopWords())}}; - QJsonObject position = request["params"].toObject()["doc"].toObject()["position"].toObject(); + config.promptTemplate->prepareRequest(config.providerRequest, updatedContext); + config.provider->prepareRequest(config.providerRequest); - if (!Settings::generalSettings().multiLineCompletion() - && processSingleLineCompletion(reply, request, accumulatedResponse)) { - return; - } - - if (isComplete || reply->isFinished()) { - if (isComplete) { - auto cleanedCompletion = removeStopWords(accumulatedResponse); - sendCompletionToClient(cleanedCompletion, request, position, true); - } else { - handleCompletion(request, accumulatedResponse); - } - m_accumulatedResponses.remove(reply); - } -} - -void LLMClientInterface::handleCompletion(const QJsonObject &request, - const QStringView 
&accumulatedCompletion) -{ - auto updatedContext = prepareContext(request, accumulatedCompletion); - sendLLMRequest(request, updatedContext); + m_requestHandler.sendLLMRequest(config, request); } ContextData LLMClientInterface::prepareContext(const QJsonObject &request, @@ -273,39 +185,16 @@ ContextData LLMClientInterface::prepareContext(const QJsonObject &request, int cursorPosition = position["character"].toInt(); int lineNumber = position["line"].toInt(); - auto textEditor = TextEditor::BaseTextEditor::currentTextEditor(); - TextEditor::TextEditorWidget *widget = textEditor->editorWidget(); - - DocumentContextReader reader(widget->textDocument()); - - QString recentChanges = ChangesManager::instance().getRecentChangesContext(textDocument); - - QString contextBefore = сontextBefore(widget, lineNumber, cursorPosition); - QString contextAfter = сontextAfter(widget, lineNumber, cursorPosition); - QString instructions - = QString("%1%2%3").arg(Settings::contextSettings().useSpecificInstructions() - ? reader.getSpecificInstructions() - : QString(), - Settings::contextSettings().useFilePathInContext() - ? reader.getLanguageAndFileInfo() - : QString(), - Settings::contextSettings().useProjectChangesCache() ? 
recentChanges - : QString()); - - return {QString("%1%2").arg(contextBefore, accumulatedCompletion), contextAfter, instructions}; -} - -void LLMClientInterface::updateProvider() -{ - m_serverUrl = QUrl(QString("%1%2").arg(Settings::generalSettings().url(), - Settings::generalSettings().endPoint())); + DocumentContextReader reader(textDocument); + return reader.prepareContext(lineNumber, cursorPosition); } void LLMClientInterface::sendCompletionToClient(const QString &completion, const QJsonObject &request, - const QJsonObject &position, bool isComplete) { + QJsonObject position = request["params"].toObject()["doc"].toObject()["position"].toObject(); + QJsonObject response; response["jsonrpc"] = "2.0"; response[LanguageServerProtocol::idKey] = request["id"]; @@ -337,69 +226,6 @@ void LLMClientInterface::sendCompletionToClient(const QString &completion, emit messageReceived(LanguageServerProtocol::JsonRpcMessage(response)); } -void LLMClientInterface::sendLLMRequest(const QJsonObject &request, const ContextData &prompt) -{ - QJsonObject providerRequest = {{"model", Settings::generalSettings().modelName.value()}, - {"stream", true}}; - - auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - currentTemplate->prepareRequest(providerRequest, prompt); - - auto &providerManager = LLMProvidersManager::instance(); - providerManager.getCurrentProvider()->prepareRequest(providerRequest); - - logMessage(QString("Sending request to llm: \nurl: %1\nRequest body:\n%2") - .arg(m_serverUrl.toString(), - QString::fromUtf8( - QJsonDocument(providerRequest).toJson(QJsonDocument::Indented)))); - - QNetworkRequest networkRequest(m_serverUrl); - networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); - - if (providerRequest.contains("api_key")) { - QString apiKey = providerRequest["api_key"].toString(); - networkRequest.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey).toUtf8()); - providerRequest.remove("api_key"); - } - - 
QNetworkReply *reply = m_manager->post(networkRequest, QJsonDocument(providerRequest).toJson()); - if (!reply) { - logMessage("Error: Failed to create network reply"); - return; - } - - QString requestId = request["id"].toString(); - m_activeRequests[requestId] = reply; - - connect(reply, &QNetworkReply::readyRead, this, [this, reply, request]() { - handleLLMResponse(reply, request); - }); - - connect(reply, &QNetworkReply::finished, this, [this, reply, requestId]() { - reply->deleteLater(); - m_activeRequests.remove(requestId); - if (reply->error() != QNetworkReply::NoError) { - logMessage(QString("Error in QodeAssist request: %1").arg(reply->errorString())); - } else { - logMessage("Request finished successfully"); - } - }); -} - -QString LLMClientInterface::removeStopWords(const QStringView &completion) -{ - QString filteredCompletion = completion.toString(); - - auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - QStringList stopWords = currentTemplate->stopWords(); - - for (const QString &stopWord : stopWords) { - filteredCompletion = filteredCompletion.replace(stopWord, ""); - } - - return filteredCompletion; -} - void LLMClientInterface::startTimeMeasurement(const QString &requestId) { m_requestStartTimes[requestId] = QDateTime::currentMSecsSinceEpoch(); diff --git a/LLMClientInterface.hpp b/LLMClientInterface.hpp index 3de65d5..2e5ce43 100644 --- a/LLMClientInterface.hpp +++ b/LLMClientInterface.hpp @@ -23,6 +23,7 @@ #include #include "QodeAssistData.hpp" +#include "core/LLMRequestHandler.hpp" class QNetworkReply; class QNetworkAccessManager; @@ -36,22 +37,13 @@ class LLMClientInterface : public LanguageClient::BaseClientInterface public: LLMClientInterface(); -public: Utils::FilePath serverDeviceTemplate() const override; void sendCompletionToClient(const QString &completion, const QJsonObject &request, - const QJsonObject &position, bool isComplete); - void handleCompletion(const QJsonObject &request, - const QStringView 
&accumulatedCompletion = QString()); - void sendLLMRequest(const QJsonObject &request, const ContextData &prompt); - void handleLLMResponse(QNetworkReply *reply, const QJsonObject &request); - - ContextData prepareContext(const QJsonObject &request, - const QStringView &accumulatedCompletion = QString{}); - void updateProvider(); + void handleCompletion(const QJsonObject &request); protected: void startImpl() override; @@ -65,19 +57,11 @@ private: void handleInitialized(const QJsonObject &request); void handleExit(const QJsonObject &request); void handleCancelRequest(const QJsonObject &request); - bool processSingleLineCompletion(QNetworkReply *reply, - const QJsonObject &request, - const QString &accumulatedCompletion); - QString сontextBefore(TextEditor::TextEditorWidget *widget, int lineNumber, int cursorPosition); - QString сontextAfter(TextEditor::TextEditorWidget *widget, int lineNumber, int cursorPosition); - QString removeStopWords(const QStringView &completion); - - QUrl m_serverUrl; - QNetworkAccessManager *m_manager; - QMap m_activeRequests; - QMap m_accumulatedResponses; + ContextData prepareContext(const QJsonObject &request, + const QStringView &accumulatedCompletion = QString{}); + LLMRequestHandler m_requestHandler; QElapsedTimer m_completionTimer; QMap m_requestStartTimes; diff --git a/LLMProvidersManager.cpp b/LLMProvidersManager.cpp index 7fdccb8..ebbaad1 100644 --- a/LLMProvidersManager.cpp +++ b/LLMProvidersManager.cpp @@ -19,6 +19,8 @@ #include "LLMProvidersManager.hpp" +#include "QodeAssistUtils.hpp" + namespace QodeAssist { LLMProvidersManager &LLMProvidersManager::instance() @@ -27,25 +29,53 @@ LLMProvidersManager &LLMProvidersManager::instance() return instance; } -QStringList LLMProvidersManager::getProviderNames() const +Providers::LLMProvider *LLMProvidersManager::setCurrentFimProvider(const QString &name) { - return m_providers.keys(); -} - -void LLMProvidersManager::setCurrentProvider(const QString &name) -{ - if 
(m_providers.contains(name)) { - m_currentProviderName = name; - } -} - -Providers::LLMProvider *LLMProvidersManager::getCurrentProvider() -{ - if (m_currentProviderName.isEmpty()) { + logMessage("Setting current FIM provider to: " + name); + if (!m_providers.contains(name)) { + logMessage("Can't find provider with name: " + name); return nullptr; } - return m_providers[m_currentProviderName]; + m_currentFimProvider = m_providers[name]; + return m_currentFimProvider; +} + +Providers::LLMProvider *LLMProvidersManager::setCurrentChatProvider(const QString &name) +{ + logMessage("Setting current chat provider to: " + name); + if (!m_providers.contains(name)) { + logMessage("Can't find chat provider with name: " + name); + return nullptr; + } + + m_currentChatProvider = m_providers[name]; + return m_currentChatProvider; +} + +Providers::LLMProvider *LLMProvidersManager::getCurrentFimProvider() +{ + if (m_currentFimProvider == nullptr) { + logMessage("Current fim provider is null"); + return nullptr; + } + + return m_currentFimProvider; +} + +Providers::LLMProvider *LLMProvidersManager::getCurrentChatProvider() +{ + if (m_currentChatProvider == nullptr) { + logMessage("Current chat provider is null"); + return nullptr; + } + + return m_currentChatProvider; +} + +QStringList LLMProvidersManager::providersNames() const +{ + return m_providers.keys(); } LLMProvidersManager::~LLMProvidersManager() diff --git a/LLMProvidersManager.hpp b/LLMProvidersManager.hpp index c9698b6..b6f3305 100644 --- a/LLMProvidersManager.hpp +++ b/LLMProvidersManager.hpp @@ -29,6 +29,7 @@ class LLMProvidersManager { public: static LLMProvidersManager &instance(); + ~LLMProvidersManager(); template void registerProvider() @@ -40,11 +41,13 @@ public: m_providers[name] = provider; } - QStringList getProviderNames() const; - void setCurrentProvider(const QString &name); - Providers::LLMProvider *getCurrentProvider(); + Providers::LLMProvider *setCurrentFimProvider(const QString &name); + 
Providers::LLMProvider *setCurrentChatProvider(const QString &name); - ~LLMProvidersManager(); + Providers::LLMProvider *getCurrentFimProvider(); + Providers::LLMProvider *getCurrentChatProvider(); + + QStringList providersNames() const; private: LLMProvidersManager() = default; @@ -52,7 +55,8 @@ private: LLMProvidersManager &operator=(const LLMProvidersManager &) = delete; QMap m_providers; - QString m_currentProviderName; + Providers::LLMProvider *m_currentFimProvider = nullptr; + Providers::LLMProvider *m_currentChatProvider = nullptr; }; } // namespace QodeAssist diff --git a/PromptTemplateManager.cpp b/PromptTemplateManager.cpp index 319f200..559f014 100644 --- a/PromptTemplateManager.cpp +++ b/PromptTemplateManager.cpp @@ -19,6 +19,8 @@ #include "PromptTemplateManager.hpp" +#include "QodeAssistUtils.hpp" + namespace QodeAssist { PromptTemplateManager &PromptTemplateManager::instance() @@ -27,27 +29,60 @@ PromptTemplateManager &PromptTemplateManager::instance() return instance; } -void PromptTemplateManager::setCurrentTemplate(const QString &name) +void PromptTemplateManager::setCurrentFimTemplate(const QString &name) { - if (m_templates.contains(name)) { - m_currentTemplateName = name; + logMessage("Setting current FIM provider to: " + name); + if (!m_fimTemplates.contains(name) || m_fimTemplates[name] == nullptr) { + logMessage("Error to set current FIM template" + name); + return; } + + m_currentFimTemplate = m_fimTemplates[name]; } -const Templates::PromptTemplate *PromptTemplateManager::getCurrentTemplate() const +Templates::PromptTemplate *PromptTemplateManager::getCurrentFimTemplate() { - auto it = m_templates.find(m_currentTemplateName); - return it != m_templates.end() ? 
it.value() : nullptr; + if (m_currentFimTemplate == nullptr) { + logMessage("Current fim provider is null"); + return nullptr; + } + + return m_currentFimTemplate; } -QStringList PromptTemplateManager::getTemplateNames() const +void PromptTemplateManager::setCurrentChatTemplate(const QString &name) { - return m_templates.keys(); + logMessage("Setting current chat provider to: " + name); + if (!m_chatTemplates.contains(name) || m_chatTemplates[name] == nullptr) { + logMessage("Error to set current chat template" + name); + return; + } + + m_currentChatTemplate = m_chatTemplates[name]; +} + +Templates::PromptTemplate *PromptTemplateManager::getCurrentChatTemplate() +{ + if (m_currentChatTemplate == nullptr) + logMessage("Current chat provider is null"); + + return m_currentChatTemplate; +} + +QStringList PromptTemplateManager::fimTemplatesNames() const +{ + return m_fimTemplates.keys(); +} + +QStringList PromptTemplateManager::chatTemplatesNames() const +{ + return m_chatTemplates.keys(); } PromptTemplateManager::~PromptTemplateManager() { - qDeleteAll(m_templates); + qDeleteAll(m_fimTemplates); + qDeleteAll(m_chatTemplates); } } // namespace QodeAssist diff --git a/PromptTemplateManager.hpp b/PromptTemplateManager.hpp index 9cdbd5d..47347d4 100644 --- a/PromptTemplateManager.hpp +++ b/PromptTemplateManager.hpp @@ -30,6 +30,7 @@ class PromptTemplateManager { public: static PromptTemplateManager &instance(); + ~PromptTemplateManager(); template void registerTemplate() @@ -38,22 +39,31 @@ public: "T must inherit from PromptTemplate"); T *template_ptr = new T(); QString name = template_ptr->name(); - m_templates[name] = template_ptr; + if (template_ptr->type() == Templates::TemplateType::Fim) { + m_fimTemplates[name] = template_ptr; + } else if (template_ptr->type() == Templates::TemplateType::Chat) { + m_chatTemplates[name] = template_ptr; + } } - void setCurrentTemplate(const QString &name); - const Templates::PromptTemplate *getCurrentTemplate() const; - QStringList 
getTemplateNames() const; + void setCurrentFimTemplate(const QString &name); + Templates::PromptTemplate *getCurrentFimTemplate(); - ~PromptTemplateManager(); + void setCurrentChatTemplate(const QString &name); + Templates::PromptTemplate *getCurrentChatTemplate(); + + QStringList fimTemplatesNames() const; + QStringList chatTemplatesNames() const; private: PromptTemplateManager() = default; PromptTemplateManager(const PromptTemplateManager &) = delete; PromptTemplateManager &operator=(const PromptTemplateManager &) = delete; - QMap m_templates; - QString m_currentTemplateName; + QMap m_fimTemplates; + QMap m_chatTemplates; + Templates::PromptTemplate *m_currentFimTemplate; + Templates::PromptTemplate *m_currentChatTemplate; }; } // namespace QodeAssist diff --git a/QodeAssistConstants.hpp b/QodeAssistConstants.hpp index 8f51de7..fc5b6d8 100644 --- a/QodeAssistConstants.hpp +++ b/QodeAssistConstants.hpp @@ -61,6 +61,12 @@ const char USE_FILE_PATH_IN_CONTEXT[] = "QodeAssist.useFilePathInContext"; const char CUSTOM_JSON_TEMPLATE[] = "QodeAssist.customJsonTemplate"; const char USE_PROJECT_CHANGES_CACHE[] = "QodeAssist.useProjectChangesCache"; const char MAX_CHANGES_CACHE_SIZE[] = "QodeAssist.maxChangesCacheSize"; +const char CHAT_LLM_PROVIDERS[] = "QodeAssist.chatLlmProviders"; +const char CHAT_URL[] = "QodeAssist.chatUrl"; +const char CHAT_END_POINT[] = "QodeAssist.chatEndPoint"; +const char CHAT_MODEL_NAME[] = "QodeAssist.chatModelName"; +const char CHAT_SELECT_MODELS[] = "QodeAssist.chatSelectModels"; +const char CHAT_PROMPTS[] = "QodeAssist.chatPrompts"; const char QODE_ASSIST_GENERAL_OPTIONS_ID[] = "QodeAssist.GeneralOptions"; const char QODE_ASSIST_GENERAL_SETTINGS_PAGE_ID[] = "QodeAssist.1GeneralSettingsPageId"; diff --git a/QodeAssistUtils.hpp b/QodeAssistUtils.hpp index 63ab992..677c1a6 100644 --- a/QodeAssistUtils.hpp +++ b/QodeAssistUtils.hpp @@ -47,6 +47,7 @@ inline void logMessage(const QString &message, bool silent = true) return; const QString 
prefixedMessage = QLatin1String("[Qode Assist] ") + message; + qDebug() << prefixedMessage; if (silent) { Core::MessageManager::writeSilently(prefixedMessage); } else { @@ -60,6 +61,8 @@ inline void logMessages(const QStringList &messages, bool silent = true) return; QStringList prefixedMessages; + qDebug() << prefixedMessages; + for (const QString &message : messages) { prefixedMessages << (QLatin1String("[Qode Assist] ") + message); } diff --git a/chat/ChatClientInterface.cpp b/chat/ChatClientInterface.cpp new file mode 100644 index 0000000..34d7942 --- /dev/null +++ b/chat/ChatClientInterface.cpp @@ -0,0 +1,153 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . 
+ */ + +#include "ChatClientInterface.hpp" +#include "LLMProvidersManager.hpp" +#include "PromptTemplateManager.hpp" +#include "QodeAssistUtils.hpp" +#include "settings/ContextSettings.hpp" +#include "settings/GeneralSettings.hpp" +#include "settings/PresetPromptsSettings.hpp" + +#include +#include +#include + +namespace QodeAssist::Chat { + +ChatClientInterface::ChatClientInterface(QObject *parent) + : QObject(parent) + , m_requestHandler(new LLMRequestHandler(this)) +{ + connect(m_requestHandler, + &LLMRequestHandler::completionReceived, + this, + [this](const QString &completion, const QJsonObject &, bool isComplete) { + handleLLMResponse(completion, isComplete); + }); + + connect(m_requestHandler, + &LLMRequestHandler::requestFinished, + this, + [this](const QString &, bool success, const QString &errorString) { + if (!success) { + emit errorOccurred(errorString); + } + }); + + // QJsonObject systemMessage; + // systemMessage["role"] = "system"; + // systemMessage["content"] = "You are a helpful C++ and QML programming assistant."; + // m_chatHistory.append(systemMessage); +} + +ChatClientInterface::~ChatClientInterface() +{ +} + +void ChatClientInterface::sendMessage(const QString &message) +{ + logMessage("Sending message: " + message); + logMessage("chatProvider " + Settings::generalSettings().chatLlmProviders.stringValue()); + logMessage("chatTemplate " + Settings::generalSettings().chatPrompts.stringValue()); + + auto chatTemplate = PromptTemplateManager::instance().getCurrentChatTemplate(); + auto chatProvider = LLMProvidersManager::instance().getCurrentChatProvider(); + + ContextData context; + context.prefix = message; + context.suffix = ""; + if (Settings::contextSettings().useSpecificInstructions()) + context.instriuctions = Settings::contextSettings().specificInstractions(); + + QJsonObject providerRequest; + providerRequest["model"] = Settings::generalSettings().chatModelName(); + providerRequest["stream"] = true; + + providerRequest["messages"] = 
m_chatHistory; + + chatTemplate->prepareRequest(providerRequest, context); + chatProvider->prepareRequest(providerRequest); + + m_chatHistory = providerRequest["messages"].toArray(); + + LLMConfig config; + config.requestType = RequestType::Chat; + config.provider = chatProvider; + config.promptTemplate = chatTemplate; + config.url = QString("%1%2").arg(Settings::generalSettings().chatUrl(), + Settings::generalSettings().chatEndPoint()); + config.providerRequest = providerRequest; + + QJsonObject request; + request["id"] = QUuid::createUuid().toString(); + + m_accumulatedResponse.clear(); + m_pendingMessage = message; + m_requestHandler->sendLLMRequest(config, request); +} + +void ChatClientInterface::handleLLMResponse(const QString &response, bool isComplete) +{ + m_accumulatedResponse += response; + logMessage("Accumulated response: " + m_accumulatedResponse); + + if (isComplete) { + logMessage("Message completed. Final response: " + m_accumulatedResponse); + emit messageReceived(m_accumulatedResponse.trimmed()); + + QJsonObject assistantMessage; + assistantMessage["role"] = "assistant"; + assistantMessage["content"] = m_accumulatedResponse.trimmed(); + m_chatHistory.append(assistantMessage); + + m_pendingMessage.clear(); + m_accumulatedResponse.clear(); + + trimChatHistory(); + } +} + +void ChatClientInterface::trimChatHistory() +{ + int maxTokens = 4000; + int totalTokens = 0; + QJsonArray newHistory; + + if (!m_chatHistory.isEmpty() + && m_chatHistory.first().toObject()["role"].toString() == "system") { + newHistory.append(m_chatHistory.first()); + } + + for (int i = m_chatHistory.size() - 1; i >= 0; --i) { + QJsonObject message = m_chatHistory[i].toObject(); + int messageTokens = message["content"].toString().length() / 4; + + if (totalTokens + messageTokens > maxTokens) { + break; + } + + newHistory.prepend(message); + totalTokens += messageTokens; + } + + m_chatHistory = newHistory; +} + +} // namespace QodeAssist::Chat diff --git 
a/chat/ChatClientInterface.hpp b/chat/ChatClientInterface.hpp new file mode 100644 index 0000000..f1d7b69 --- /dev/null +++ b/chat/ChatClientInterface.hpp @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include +#include +#include "QodeAssistData.hpp" +#include "core/LLMRequestHandler.hpp" + +namespace QodeAssist::Chat { + +class ChatClientInterface : public QObject +{ + Q_OBJECT + +public: + explicit ChatClientInterface(QObject *parent = nullptr); + ~ChatClientInterface(); + + void sendMessage(const QString &message); + +signals: + void messageReceived(const QString &message); + void errorOccurred(const QString &error); + +private: + void handleLLMResponse(const QString &response, bool isComplete); + void trimChatHistory(); + + LLMRequestHandler *m_requestHandler; + QString m_accumulatedResponse; + QString m_pendingMessage; + QJsonArray m_chatHistory; +}; + +} // namespace QodeAssist::Chat diff --git a/chat/ChatOutputPane.cpp b/chat/ChatOutputPane.cpp new file mode 100644 index 0000000..4721601 --- /dev/null +++ b/chat/ChatOutputPane.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. 
+ * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#include "ChatOutputPane.h" + +#include "QodeAssisttr.h" + +namespace QodeAssist::Chat { + +ChatOutputPane::ChatOutputPane(QObject *parent) + : Core::IOutputPane(parent) + , m_chatWidget(new ChatWidget) +{ + setId("QodeAssistChat"); + setDisplayName(Tr::tr("QodeAssist Chat")); + setPriorityInStatusBar(-40); +} + +ChatOutputPane::~ChatOutputPane() +{ + delete m_chatWidget; +} + +QWidget *ChatOutputPane::outputWidget(QWidget *) +{ + return m_chatWidget; +} + +QList ChatOutputPane::toolBarWidgets() const +{ + return {}; +} + +void ChatOutputPane::clearContents() +{ + m_chatWidget->clear(); +} + +void ChatOutputPane::visibilityChanged(bool visible) +{ + if (visible) + m_chatWidget->scrollToBottom(); +} + +void ChatOutputPane::setFocus() +{ + m_chatWidget->setFocus(); +} + +bool ChatOutputPane::hasFocus() const +{ + return m_chatWidget->hasFocus(); +} + +bool ChatOutputPane::canFocus() const +{ + return true; +} + +bool ChatOutputPane::canNavigate() const +{ + return false; +} + +bool ChatOutputPane::canNext() const +{ + return false; +} + +bool ChatOutputPane::canPrevious() const +{ + return false; +} + +void ChatOutputPane::goToNext() {} + +void ChatOutputPane::goToPrev() {} + +} // namespace QodeAssist::Chat diff --git a/chat/ChatOutputPane.h b/chat/ChatOutputPane.h new file mode 100644 index 0000000..12e8267 --- /dev/null +++ 
b/chat/ChatOutputPane.h @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include "ChatWidget.h" +#include + +namespace QodeAssist::Chat { + +class ChatOutputPane : public Core::IOutputPane +{ + Q_OBJECT + +public: + explicit ChatOutputPane(QObject *parent = nullptr); + ~ChatOutputPane() override; + + QWidget *outputWidget(QWidget *parent) override; + QList toolBarWidgets() const override; + void clearContents() override; + void visibilityChanged(bool visible) override; + void setFocus() override; + bool hasFocus() const override; + bool canFocus() const override; + bool canNavigate() const override; + bool canNext() const override; + bool canPrevious() const override; + void goToNext() override; + void goToPrev() override; + +private: + ChatWidget *m_chatWidget; +}; + +} // namespace QodeAssist::Chat diff --git a/chat/ChatWidget.cpp b/chat/ChatWidget.cpp new file mode 100644 index 0000000..dadc8f0 --- /dev/null +++ b/chat/ChatWidget.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. 
+ * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#include "ChatWidget.h" +#include "QodeAssistUtils.hpp" + +#include +#include +#include +#include +#include + +namespace QodeAssist::Chat { + +ChatWidget::ChatWidget(QWidget *parent) + : QWidget(parent) + , m_showTimestamp(false) + , m_chatClient(new ChatClientInterface(this)) +{ + setupUi(); + + connect(m_sendButton, &QPushButton::clicked, this, &ChatWidget::sendMessage); + connect(m_messageInput, &QLineEdit::returnPressed, this, &ChatWidget::sendMessage); + + connect(m_chatClient, &ChatClientInterface::messageReceived, this, &ChatWidget::receiveMessage); + connect(m_chatClient, &ChatClientInterface::errorOccurred, this, &ChatWidget::handleError); + + logMessage("ChatWidget initialized"); +} + +void ChatWidget::setupUi() +{ + m_chatDisplay = new QTextEdit(this); + m_chatDisplay->setReadOnly(true); + + m_messageInput = new QLineEdit(this); + m_sendButton = new QPushButton("Send", this); + + QHBoxLayout *inputLayout = new QHBoxLayout; + inputLayout->addWidget(m_messageInput); + inputLayout->addWidget(m_sendButton); + + QVBoxLayout *mainLayout = new QVBoxLayout(this); + mainLayout->addWidget(m_chatDisplay); + mainLayout->addLayout(inputLayout); + + setLayout(mainLayout); +} + +void ChatWidget::sendMessage() +{ + QString message = m_messageInput->text().trimmed(); + if (!message.isEmpty()) { + logMessage("Sending message: " + message); + 
addMessage(message, true); + m_chatClient->sendMessage(message); + m_messageInput->clear(); + addMessage("AI is typing...", false); + } +} + +void ChatWidget::receiveMessage(const QString &message) +{ + logMessage("Received message: " + message); + updateLastAIMessage(message); +} + +void ChatWidget::receivePartialMessage(const QString &partialMessage) +{ + logMessage("Received partial message: " + partialMessage); + m_currentAIResponse += partialMessage; + updateLastAIMessage(m_currentAIResponse); +} + +void ChatWidget::onMessageCompleted() +{ + logMessage("Message completed. Final response: " + m_currentAIResponse); + updateLastAIMessage(m_currentAIResponse); + m_currentAIResponse.clear(); + scrollToBottom(); +} + +void ChatWidget::handleError(const QString &error) +{ + logMessage("Error occurred: " + error); + addMessage("Error: " + error, false); +} + +void ChatWidget::addMessage(const QString &message, bool fromUser) +{ + auto prefix = fromUser ? "You: " : "AI: "; + QString timestamp = m_showTimestamp ? QDateTime::currentDateTime().toString("[hh:mm:ss] ") : ""; + QString fullMessage = timestamp + prefix + message; + logMessage("Adding message to display: " + fullMessage); + m_chatDisplay->append(fullMessage); + scrollToBottom(); +} + +void ChatWidget::updateLastAIMessage(const QString &message) +{ + logMessage("Updating last AI message: " + message); + QTextCursor cursor = m_chatDisplay->textCursor(); + cursor.movePosition(QTextCursor::End); + cursor.movePosition(QTextCursor::StartOfBlock, QTextCursor::KeepAnchor); + cursor.removeSelectedText(); + + QString timestamp = m_showTimestamp ? 
QDateTime::currentDateTime().toString("[hh:mm:ss] ") : ""; + cursor.insertText(timestamp + "AI: " + message); + + cursor.movePosition(QTextCursor::End); + m_chatDisplay->setTextCursor(cursor); + + scrollToBottom(); + m_chatDisplay->repaint(); +} + +void ChatWidget::clear() +{ + m_chatDisplay->clear(); + m_currentAIResponse.clear(); +} + +void ChatWidget::scrollToBottom() +{ + QScrollBar *scrollBar = m_chatDisplay->verticalScrollBar(); + scrollBar->setValue(scrollBar->maximum()); +} + +void ChatWidget::setShowTimestamp(bool show) +{ + m_showTimestamp = show; +} + +} // namespace QodeAssist::Chat diff --git a/chat/ChatWidget.h b/chat/ChatWidget.h new file mode 100644 index 0000000..94bbba0 --- /dev/null +++ b/chat/ChatWidget.h @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . 
+ */ + +#pragma once + +#include +#include +#include +#include + +#include "ChatClientInterface.hpp" + +namespace QodeAssist::Chat { + +class ChatWidget : public QWidget +{ + Q_OBJECT + +public: + explicit ChatWidget(QWidget *parent = nullptr); + + void clear(); + void scrollToBottom(); + void setShowTimestamp(bool show); + + void receiveMessage(const QString &message); +private slots: + void sendMessage(); + void receivePartialMessage(const QString &partialMessage); + void onMessageCompleted(); + void handleError(const QString &error); + +private: + QTextEdit *m_chatDisplay; + QLineEdit *m_messageInput; + QPushButton *m_sendButton; + bool m_showTimestamp; + ChatClientInterface *m_chatClient; + QString m_currentAIResponse; + + void setupUi(); + void addMessage(const QString &message, bool fromUser = true); + void updateLastAIMessage(const QString &message); +}; + +} // namespace QodeAssist::Chat diff --git a/core/LLMRequestConfig.hpp b/core/LLMRequestConfig.hpp new file mode 100644 index 0000000..c2fe73b --- /dev/null +++ b/core/LLMRequestConfig.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . 
+ */ + +#pragma once + +#include +#include +#include "providers/LLMProvider.hpp" +#include "templates/PromptTemplate.hpp" + +namespace QodeAssist { + +enum class RequestType { Fim, Chat }; + +struct LLMConfig +{ + QUrl url; + Providers::LLMProvider *provider; + Templates::PromptTemplate *promptTemplate; + QJsonObject providerRequest; + RequestType requestType; +}; + +} // namespace QodeAssist diff --git a/core/LLMRequestHandler.cpp b/core/LLMRequestHandler.cpp new file mode 100644 index 0000000..bb5b3ae --- /dev/null +++ b/core/LLMRequestHandler.cpp @@ -0,0 +1,162 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . 
+ */ + +#include "LLMRequestHandler.hpp" +#include "LLMProvidersManager.hpp" +#include "QodeAssistUtils.hpp" +#include "settings/GeneralSettings.hpp" + +#include +#include + +namespace QodeAssist { + +LLMRequestHandler::LLMRequestHandler(QObject *parent) + : QObject(parent) + , m_manager(new QNetworkAccessManager(this)) +{} + +void LLMRequestHandler::sendLLMRequest(const LLMConfig &config, const QJsonObject &request) +{ + logMessage(QString("Sending request to llm: \nurl: %1\nRequest body:\n%2") + .arg(config.url.toString(), + QString::fromUtf8( + QJsonDocument(config.providerRequest).toJson(QJsonDocument::Indented)))); + + QNetworkRequest networkRequest(config.url); + prepareNetworkRequest(networkRequest, config.providerRequest); + + QNetworkReply *reply = m_manager->post(networkRequest, + QJsonDocument(config.providerRequest).toJson()); + if (!reply) { + logMessage("Error: Failed to create network reply"); + return; + } + + QString requestId = request["id"].toString(); + m_activeRequests[requestId] = reply; + + connect(reply, &QNetworkReply::readyRead, this, [this, reply, request, config]() { + handleLLMResponse(reply, request, config); + }); + + connect(reply, &QNetworkReply::finished, this, [this, reply, requestId]() { + reply->deleteLater(); + m_activeRequests.remove(requestId); + if (reply->error() != QNetworkReply::NoError) { + logMessage(QString("Error in QodeAssist request: %1").arg(reply->errorString())); + emit requestFinished(requestId, false, reply->errorString()); + } else { + logMessage("Request finished successfully"); + emit requestFinished(requestId, true, QString()); + } + }); +} + +void LLMRequestHandler::handleLLMResponse(QNetworkReply *reply, + const QJsonObject &request, + const LLMConfig &config) +{ + qDebug() << "Handling LLM response" << request; + + QString &accumulatedResponse = m_accumulatedResponses[reply]; + + bool isComplete = config.provider->handleResponse(reply, accumulatedResponse); + + if (config.requestType == RequestType::Fim) 
{ + if (!Settings::generalSettings().multiLineCompletion() + && processSingleLineCompletion(reply, request, accumulatedResponse, config)) { + return; + } + } + + if (isComplete || reply->isFinished()) { + if (isComplete) { + if (config.requestType == RequestType::Fim) { + auto cleanedCompletion = removeStopWords(accumulatedResponse, + config.promptTemplate->stopWords()); + emit completionReceived(cleanedCompletion, request, true); + } else { + emit completionReceived(accumulatedResponse, request, true); + } + } else { + emit completionReceived(accumulatedResponse, request, false); + } + m_accumulatedResponses.remove(reply); + } +} + +bool LLMRequestHandler::cancelRequest(const QString &id) +{ + if (m_activeRequests.contains(id)) { + QNetworkReply *reply = m_activeRequests[id]; + reply->abort(); + m_activeRequests.remove(id); + m_accumulatedResponses.remove(reply); + emit requestCancelled(id); + return true; + } + return false; +} + +void LLMRequestHandler::prepareNetworkRequest(QNetworkRequest &networkRequest, + const QJsonObject &providerRequest) +{ + networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + + if (providerRequest.contains("api_key")) { + QString apiKey = providerRequest["api_key"].toString(); + networkRequest.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey).toUtf8()); + } +} + +bool LLMRequestHandler::processSingleLineCompletion(QNetworkReply *reply, + const QJsonObject &request, + const QString &accumulatedResponse, + const LLMConfig &config) +{ + int newlinePos = accumulatedResponse.indexOf('\n'); + + if (newlinePos != -1) { + QString singleLineCompletion = accumulatedResponse.left(newlinePos).trimmed(); + singleLineCompletion = removeStopWords(singleLineCompletion, + config.promptTemplate->stopWords()); + + emit completionReceived(singleLineCompletion, request, true); + m_accumulatedResponses.remove(reply); + reply->abort(); + + return true; + } + return false; +} + +QString 
LLMRequestHandler::removeStopWords(const QStringView &completion, + const QStringList &stopWords) +{ + QString filteredCompletion = completion.toString(); + + for (const QString &stopWord : stopWords) { + filteredCompletion = filteredCompletion.replace(stopWord, ""); + } + + return filteredCompletion; +} + +} // namespace QodeAssist diff --git a/core/LLMRequestHandler.hpp b/core/LLMRequestHandler.hpp new file mode 100644 index 0000000..aadd590 --- /dev/null +++ b/core/LLMRequestHandler.hpp @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . 
+ */ + +#pragma once + +#include +#include +#include + +#include "QodeAssistData.hpp" +#include "core/LLMRequestConfig.hpp" + +class QNetworkReply; + +namespace QodeAssist { + +class LLMRequestHandler : public QObject +{ + Q_OBJECT + +public: + explicit LLMRequestHandler(QObject *parent = nullptr); + + void sendLLMRequest(const LLMConfig &config, const QJsonObject &request); + void handleLLMResponse(QNetworkReply *reply, + const QJsonObject &request, + const LLMConfig &config); + bool cancelRequest(const QString &id); + +signals: + void completionReceived(const QString &completion, const QJsonObject &request, bool isComplete); + void requestFinished(const QString &requestId, bool success, const QString &errorString); + void requestCancelled(const QString &id); + +private: + QNetworkAccessManager *m_manager; + QMap m_activeRequests; + QMap m_accumulatedResponses; + + void prepareNetworkRequest(QNetworkRequest &networkRequest, const QJsonObject &providerRequest); + bool processSingleLineCompletion(QNetworkReply *reply, + const QJsonObject &request, + const QString &accumulatedResponse, + const LLMConfig &config); + QString removeStopWords(const QStringView &completion, const QStringList &stopWords); +}; + +} // namespace QodeAssist diff --git a/providers/LLMProvider.hpp b/providers/LLMProvider.hpp index 55cfde0..f02e639 100644 --- a/providers/LLMProvider.hpp +++ b/providers/LLMProvider.hpp @@ -35,6 +35,7 @@ public: virtual QString name() const = 0; virtual QString url() const = 0; virtual QString completionEndpoint() const = 0; + virtual QString chatEndpoint() const = 0; virtual void prepareRequest(QJsonObject &request) = 0; virtual bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) = 0; diff --git a/providers/LMStudioProvider.cpp b/providers/LMStudioProvider.cpp index dba91bc..705d90c 100644 --- a/providers/LMStudioProvider.cpp +++ b/providers/LMStudioProvider.cpp @@ -48,12 +48,14 @@ QString LMStudioProvider::completionEndpoint() const return 
"/v1/chat/completions"; } +QString LMStudioProvider::chatEndpoint() const +{ + return "/v1/chat/completions"; +} + void LMStudioProvider::prepareRequest(QJsonObject &request) { auto &settings = Settings::presetPromptsSettings(); - const auto ¤tTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - if (currentTemplate->name() == "Custom Template") - return; if (request.contains("prompt")) { QJsonArray messages{ {QJsonObject{{"role", "user"}, {"content", request.take("prompt").toString()}}}}; @@ -62,7 +64,6 @@ void LMStudioProvider::prepareRequest(QJsonObject &request) request["max_tokens"] = settings.maxTokens(); request["temperature"] = settings.temperature(); - request["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords()); if (settings.useTopP()) request["top_p"] = settings.topP(); if (settings.useTopK()) diff --git a/providers/LMStudioProvider.hpp b/providers/LMStudioProvider.hpp index 8285c32..1cfc2a9 100644 --- a/providers/LMStudioProvider.hpp +++ b/providers/LMStudioProvider.hpp @@ -31,6 +31,7 @@ public: QString name() const override; QString url() const override; QString completionEndpoint() const override; + QString chatEndpoint() const override; void prepareRequest(QJsonObject &request) override; bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; QList getInstalledModels(const Utils::Environment &env) override; diff --git a/providers/OllamaProvider.cpp b/providers/OllamaProvider.cpp index ab9a1ce..091d24f 100644 --- a/providers/OllamaProvider.cpp +++ b/providers/OllamaProvider.cpp @@ -48,18 +48,19 @@ QString OllamaProvider::completionEndpoint() const return "/api/generate"; } +QString OllamaProvider::chatEndpoint() const +{ + return "/api/chat"; +} + void OllamaProvider::prepareRequest(QJsonObject &request) { auto &settings = Settings::presetPromptsSettings(); - auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - if (currentTemplate->name() == "Custom Template") - return; 
QJsonObject options; options["num_predict"] = settings.maxTokens(); options["keep_alive"] = settings.ollamaLivetime(); options["temperature"] = settings.temperature(); - options["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords()); if (settings.useTopP()) options["top_p"] = settings.topP(); if (settings.useTopK()) @@ -73,28 +74,52 @@ void OllamaProvider::prepareRequest(QJsonObject &request) bool OllamaProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse) { + QString endpoint = reply->url().path(); + bool isComplete = false; while (reply->canReadLine()) { QByteArray line = reply->readLine().trimmed(); if (line.isEmpty()) { continue; } - QJsonDocument jsonResponse = QJsonDocument::fromJson(line); - if (jsonResponse.isNull()) { - qWarning() << "Invalid JSON response from Ollama:" << line; + + QJsonDocument doc = QJsonDocument::fromJson(line); + if (doc.isNull()) { + logMessage("Invalid JSON response from Ollama: " + QString::fromUtf8(line)); continue; } - QJsonObject responseObj = jsonResponse.object(); - if (responseObj.contains("response")) { - QString completion = responseObj["response"].toString(); - accumulatedResponse += completion; + QJsonObject responseObj = doc.object(); + + if (responseObj.contains("error")) { + QString errorMessage = responseObj["error"].toString(); + logMessage("Error in Ollama response: " + errorMessage); + return false; } - if (responseObj["done"].toBool()) { + + if (endpoint == completionEndpoint()) { + if (responseObj.contains("response")) { + QString completion = responseObj["response"].toString(); + accumulatedResponse += completion; + } + } else if (endpoint == chatEndpoint()) { + if (responseObj.contains("message")) { + QJsonObject message = responseObj["message"].toObject(); + if (message.contains("content")) { + QString content = message["content"].toString(); + accumulatedResponse += content; + } + } + } else { + logMessage("Unknown endpoint: " + endpoint); + } + + if 
(responseObj.contains("done") && responseObj["done"].toBool()) { isComplete = true; break; } } + return isComplete; } diff --git a/providers/OllamaProvider.hpp b/providers/OllamaProvider.hpp index 229d1ae..c78f8f3 100644 --- a/providers/OllamaProvider.hpp +++ b/providers/OllamaProvider.hpp @@ -31,6 +31,7 @@ public: QString name() const override; QString url() const override; QString completionEndpoint() const override; + QString chatEndpoint() const override; void prepareRequest(QJsonObject &request) override; bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; QList getInstalledModels(const Utils::Environment &env) override; diff --git a/providers/OpenAICompatProvider.cpp b/providers/OpenAICompatProvider.cpp index 6214678..a366ed4 100644 --- a/providers/OpenAICompatProvider.cpp +++ b/providers/OpenAICompatProvider.cpp @@ -46,13 +46,14 @@ QString OpenAICompatProvider::completionEndpoint() const return "/v1/chat/completions"; } +QString OpenAICompatProvider::chatEndpoint() const +{ + return "/v1/chat/completions"; +} + void OpenAICompatProvider::prepareRequest(QJsonObject &request) { auto &settings = Settings::presetPromptsSettings(); - const auto ¤tTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - if (currentTemplate->name() == "Custom Template") - return; - if (request.contains("prompt")) { QJsonArray messages{ {QJsonObject{{"role", "user"}, {"content", request.take("prompt").toString()}}}}; @@ -61,7 +62,6 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request) request["max_tokens"] = settings.maxTokens(); request["temperature"] = settings.temperature(); - request["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords()); if (settings.useTopP()) request["top_p"] = settings.topP(); if (settings.useTopK()) diff --git a/providers/OpenAICompatProvider.hpp b/providers/OpenAICompatProvider.hpp index 8ca5824..417c643 100644 --- a/providers/OpenAICompatProvider.hpp +++ 
b/providers/OpenAICompatProvider.hpp @@ -31,6 +31,7 @@ public: QString name() const override; QString url() const override; QString completionEndpoint() const override; + QString chatEndpoint() const override; void prepareRequest(QJsonObject &request) override; bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; QList getInstalledModels(const Utils::Environment &env) override; diff --git a/qodeassist.cpp b/qodeassist.cpp index a679547..9dc8128 100644 --- a/qodeassist.cpp +++ b/qodeassist.cpp @@ -26,8 +26,9 @@ #include #include #include +#include +#include #include - #include #include @@ -41,11 +42,16 @@ #include "LLMProvidersManager.hpp" #include "PromptTemplateManager.hpp" #include "QodeAssistClient.hpp" +#include "chat/ChatOutputPane.h" #include "providers/LMStudioProvider.hpp" #include "providers/OllamaProvider.hpp" #include "providers/OpenAICompatProvider.hpp" -#include "templates/CodeLLamaTemplate.hpp" + +#include "settings/GeneralSettings.hpp" +#include "templates/CodeLlamaFimTemplate.hpp" +#include "templates/CodeLlamaInstruct.hpp" #include "templates/CustomTemplate.hpp" +#include "templates/DeepSeekCoderChatTemplate.hpp" #include "templates/DeepSeekCoderV2.hpp" #include "templates/StarCoder2Template.hpp" @@ -78,10 +84,12 @@ public: providerManager.registerProvider(); auto &templateManager = PromptTemplateManager::instance(); - templateManager.registerTemplate(); + templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); + templateManager.registerTemplate(); + templateManager.registerTemplate(); Utils::Icon QCODEASSIST_ICON( {{":/resources/images/qoderassist-icon.png", Utils::Theme::IconsBaseColor}}); @@ -106,6 +114,8 @@ public: auto toggleButton = new QToolButton; toggleButton->setDefaultAction(requestAction.contextAction()); StatusBarManager::addStatusBarWidget(toggleButton, StatusBarManager::RightCorner); + + m_chatOutputPane = new 
Chat::ChatOutputPane(this); } void extensionsInitialized() final @@ -139,6 +149,7 @@ public: private: QPointer m_qodeAssistClient; + QPointer m_chatOutputPane; }; } // namespace QodeAssist::Internal diff --git a/settings/GeneralSettings.cpp b/settings/GeneralSettings.cpp index 25af1ec..58875c8 100644 --- a/settings/GeneralSettings.cpp +++ b/settings/GeneralSettings.cpp @@ -84,9 +84,8 @@ GeneralSettings::GeneralSettings() autoCompletionTypingInterval.setDefaultValue(2000); llmProviders.setSettingsKey(Constants::LLM_PROVIDERS); - llmProviders.setDisplayName(Tr::tr("FIM Provider:")); + llmProviders.setDisplayName(Tr::tr("AI Suggest Provider:")); llmProviders.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox); - llmProviders.setDefaultValue(0); url.setSettingsKey(Constants::URL); url.setLabelText(Tr::tr("URL:")); @@ -97,37 +96,53 @@ GeneralSettings::GeneralSettings() endPoint.setDisplayStyle(Utils::StringAspect::LineEditDisplay); modelName.setSettingsKey(Constants::MODEL_NAME); - modelName.setLabelText(Tr::tr("LLM Name:")); + modelName.setLabelText(Tr::tr("Model name:")); modelName.setDisplayStyle(Utils::StringAspect::LineEditDisplay); selectModels.m_buttonText = Tr::tr("Select Fill-In-the-Middle Model"); fimPrompts.setDisplayName(Tr::tr("Fill-In-the-Middle Prompt")); fimPrompts.setSettingsKey(Constants::FIM_PROMPTS); - fimPrompts.setDefaultValue(0); fimPrompts.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox); resetToDefaults.m_buttonText = Tr::tr("Reset Page to Defaults"); - const auto &manager = LLMProvidersManager::instance(); - if (!manager.getProviderNames().isEmpty()) { - const auto providerNames = manager.getProviderNames(); - for (const QString &name : providerNames) { - llmProviders.addOption(name); - } - } + chatLlmProviders.setSettingsKey(Constants::CHAT_LLM_PROVIDERS); + chatLlmProviders.setDisplayName(Tr::tr("AI Chat Provider:")); + chatLlmProviders.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox); - const auto 
&promptManager = PromptTemplateManager::instance(); - if (!promptManager.getTemplateNames().isEmpty()) { - const auto promptNames = promptManager.getTemplateNames(); - for (const QString &name : promptNames) { - fimPrompts.addOption(name); - } - } + chatUrl.setSettingsKey(Constants::CHAT_URL); + chatUrl.setLabelText(Tr::tr("URL:")); + chatUrl.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + + chatEndPoint.setSettingsKey(Constants::CHAT_END_POINT); + chatEndPoint.setLabelText(Tr::tr("Chat Endpoint:")); + chatEndPoint.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + + chatModelName.setSettingsKey(Constants::CHAT_MODEL_NAME); + chatModelName.setLabelText(Tr::tr("Model name:")); + chatModelName.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + + chatSelectModels.m_buttonText = Tr::tr("Select Chat Model"); + + chatPrompts.setDisplayName(Tr::tr("Chat Prompt")); + chatPrompts.setSettingsKey(Constants::CHAT_PROMPTS); + chatPrompts.setDisplayStyle(Utils::SelectionAspect::DisplayStyle::ComboBox); + + loadProviders(); + loadPrompts(); readSettings(); - LLMProvidersManager::instance().setCurrentProvider(llmProviders.stringValue()); - PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.stringValue()); + auto fimProviderName = llmProviders.displayForIndex(llmProviders.value()); + setCurrentFimProvider(fimProviderName); + auto chatProviderName = chatLlmProviders.displayForIndex(chatLlmProviders.value()); + setCurrentChatProvider(chatProviderName); + + auto nameFimPromts = fimPrompts.displayForIndex(fimPrompts.value()); + PromptTemplateManager::instance().setCurrentFimTemplate(nameFimPromts); + auto nameChatPromts = chatPrompts.displayForIndex(chatPrompts.value()); + PromptTemplateManager::instance().setCurrentChatTemplate(nameChatPromts); + setLoggingEnabled(enableLogging()); setupConnections(); @@ -135,23 +150,29 @@ GeneralSettings::GeneralSettings() setLayouter([this]() { using namespace Layouting; - auto rootLayout = 
Column{Row{enableQodeAssist, Stretch{1}, resetToDefaults}, - enableAutoComplete, - multiLineCompletion, - Row{autoCompletionCharThreshold, - autoCompletionTypingInterval, - startSuggestionTimer, - Stretch{1}}, - Space{8}, - enableLogging, - Space{8}, - Row{llmProviders, Stretch{1}}, - Row{url, endPoint, urlIndicator}, - Space{8}, - Row{selectModels, modelName, modelIndicator}, - Space{8}, - fimPrompts, - Stretch{1}}; + auto rootLayout + = Column{Row{enableQodeAssist, Stretch{1}, resetToDefaults}, + enableAutoComplete, + multiLineCompletion, + Row{autoCompletionCharThreshold, + autoCompletionTypingInterval, + startSuggestionTimer, + Stretch{1}}, + Space{8}, + enableLogging, + Space{8}, + Group{title(Tr::tr("AI Suggestions")), + Column{Row{llmProviders, Stretch{1}}, + Row{url, endPoint, fimUrlIndicator}, + Row{selectModels, modelName, fimModelIndicator}, + Row{fimPrompts, Stretch{1}}}}, + Space{16}, + Group{title(Tr::tr("AI Chat")), + Column{Row{chatLlmProviders, Stretch{1}}, + Row{chatUrl, chatEndPoint, chatUrlIndicator}, + Row{chatSelectModels, chatModelName, chatModelIndicator}, + Row{chatPrompts, Stretch{1}}}}, + Stretch{1}}; return rootLayout; }); @@ -161,17 +182,32 @@ GeneralSettings::GeneralSettings() void GeneralSettings::setupConnections() { connect(&llmProviders, &Utils::SelectionAspect::volatileValueChanged, this, [this]() { - int index = llmProviders.volatileValue(); - logMessage(QString("currentProvider %1").arg(llmProviders.displayForIndex(index))); - LLMProvidersManager::instance().setCurrentProvider(llmProviders.displayForIndex(index)); - updateProviderSettings(); + auto providerName = llmProviders.displayForIndex(llmProviders.volatileValue()); + setCurrentFimProvider(providerName); }); + connect(&chatLlmProviders, &Utils::SelectionAspect::volatileValueChanged, this, [this]() { + auto providerName = chatLlmProviders.displayForIndex(chatLlmProviders.volatileValue()); + setCurrentChatProvider(providerName); + }); + connect(&fimPrompts, 
&Utils::SelectionAspect::volatileValueChanged, this, [this]() { int index = fimPrompts.volatileValue(); - logMessage(QString("currentPrompt %1").arg(fimPrompts.displayForIndex(index))); - PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.displayForIndex(index)); + PromptTemplateManager::instance().setCurrentFimTemplate(fimPrompts.displayForIndex(index)); }); - connect(&selectModels, &ButtonAspect::clicked, this, [this]() { showModelSelectionDialog(); }); + connect(&chatPrompts, &Utils::SelectionAspect::volatileValueChanged, this, [this]() { + int index = chatPrompts.volatileValue(); + PromptTemplateManager::instance().setCurrentChatTemplate(chatPrompts.displayForIndex(index)); + }); + + connect(&selectModels, &ButtonAspect::clicked, this, [this]() { + auto *provider = LLMProvidersManager::instance().getCurrentFimProvider(); + showModelSelectionDialog(&modelName, provider); + }); + connect(&chatSelectModels, &ButtonAspect::clicked, this, [this]() { + auto *provider = LLMProvidersManager::instance().getCurrentChatProvider(); + showModelSelectionDialog(&chatModelName, provider); + }); + connect(&enableLogging, &Utils::BoolAspect::volatileValueChanged, this, [this]() { setLoggingEnabled(enableLogging.volatileValue()); }); @@ -185,22 +221,19 @@ void GeneralSettings::setupConnections() &Utils::StringAspect::volatileValueChanged, this, &GeneralSettings::updateStatusIndicators); + connect(&chatUrl, + &Utils::StringAspect::volatileValueChanged, + this, + &GeneralSettings::updateStatusIndicators); + connect(&chatModelName, + &Utils::StringAspect::volatileValueChanged, + this, + &GeneralSettings::updateStatusIndicators); } -void GeneralSettings::updateProviderSettings() +void GeneralSettings::showModelSelectionDialog(Utils::StringAspect *modelNameObj, + Providers::LLMProvider *provider) { - const auto provider = LLMProvidersManager::instance().getCurrentProvider(); - - if (provider) { - url.setVolatileValue(provider->url()); - 
endPoint.setVolatileValue(provider->completionEndpoint()); - modelName.setVolatileValue(""); - } -} - -void GeneralSettings::showModelSelectionDialog() -{ - auto *provider = LLMProvidersManager::instance().getCurrentProvider(); Utils::Environment env = Utils::Environment::systemEnvironment(); if (provider) { @@ -215,7 +248,7 @@ void GeneralSettings::showModelSelectionDialog() &ok); if (ok && !selectedModel.isEmpty()) { - modelName.setVolatileValue(selectedModel); + modelNameObj->setVolatileValue(selectedModel); writeSettings(); } } @@ -233,42 +266,58 @@ void GeneralSettings::resetPageToDefaults() if (reply == QMessageBox::Yes) { resetAspect(enableQodeAssist); resetAspect(enableAutoComplete); - resetAspect(llmProviders); - resetAspect(url); - resetAspect(endPoint); - resetAspect(modelName); - resetAspect(fimPrompts); resetAspect(enableLogging); resetAspect(startSuggestionTimer); resetAspect(autoCompletionTypingInterval); resetAspect(autoCompletionCharThreshold); } - fimPrompts.setStringValue("StarCoder2"); - llmProviders.setStringValue("Ollama"); + int fimIndex = llmProviders.indexForDisplay("Ollama"); + llmProviders.setVolatileValue(fimIndex); + int chatIndex = chatLlmProviders.indexForDisplay("Ollama"); + chatLlmProviders.setVolatileValue(chatIndex); + modelName.setVolatileValue(""); + chatModelName.setVolatileValue(""); + updateStatusIndicators(); } void GeneralSettings::updateStatusIndicators() { - bool urlValid = !url.volatileValue().isEmpty() && !endPoint.volatileValue().isEmpty(); - bool modelValid = !modelName.volatileValue().isEmpty(); + bool fimUrlValid = !url.volatileValue().isEmpty() && !endPoint.volatileValue().isEmpty(); + bool fimModelValid = !modelName.volatileValue().isEmpty(); + bool chatUrlValid = !chatUrl.volatileValue().isEmpty() + && !chatEndPoint.volatileValue().isEmpty(); + bool chatModelValid = !chatModelName.volatileValue().isEmpty(); - bool pingSuccessful = false; - if (urlValid) { + bool fimPingSuccessful = false; + if (fimUrlValid) { 
QUrl pingUrl(url.volatileValue()); - pingSuccessful = QodeAssist::pingUrl(pingUrl); + fimPingSuccessful = QodeAssist::pingUrl(pingUrl); + } + bool chatPingSuccessful = false; + if (chatUrlValid) { + QUrl pingUrl(chatUrl.volatileValue()); + chatPingSuccessful = QodeAssist::pingUrl(pingUrl); } - setIndicatorStatus(modelIndicator, - modelValid ? tr("Model is properly configured") - : tr("No model selected or model name is invalid"), - modelValid); + setIndicatorStatus(fimModelIndicator, + fimModelValid ? tr("Model is properly configured") + : tr("No model selected or model name is invalid"), + fimModelValid); + setIndicatorStatus(fimUrlIndicator, + fimPingSuccessful ? tr("Server is reachable") + : tr("Server is not reachable or URL is invalid"), + fimPingSuccessful); - setIndicatorStatus(urlIndicator, - pingSuccessful ? tr("Server is reachable") - : tr("Server is not reachable or URL is invalid"), - pingSuccessful); + setIndicatorStatus(chatModelIndicator, + chatModelValid ? tr("Model is properly configured") + : tr("No model selected or model name is invalid"), + chatModelValid); + setIndicatorStatus(chatUrlIndicator, + chatPingSuccessful ? 
tr("Server is reachable") + : tr("Server is not reachable or URL is invalid"), + chatPingSuccessful); } void GeneralSettings::setIndicatorStatus(Utils::StringAspect &indicator, @@ -280,6 +329,44 @@ void GeneralSettings::setIndicatorStatus(Utils::StringAspect &indicator, indicator.setToolTip(tooltip); } +void GeneralSettings::setCurrentFimProvider(const QString &name) +{ + const auto provider = LLMProvidersManager::instance().setCurrentFimProvider(name); + if (!provider) + return; + + url.setValue(provider->url()); + endPoint.setValue(provider->completionEndpoint()); +} + +void GeneralSettings::setCurrentChatProvider(const QString &name) +{ + const auto provider = LLMProvidersManager::instance().setCurrentChatProvider(name); + if (!provider) + return; + + chatUrl.setValue(provider->url()); + chatEndPoint.setValue(provider->chatEndpoint()); +} + +void GeneralSettings::loadProviders() +{ + for (const auto &name : LLMProvidersManager::instance().providersNames()) { + llmProviders.addOption(name); + chatLlmProviders.addOption(name); + } +} + +void GeneralSettings::loadPrompts() +{ + for (const auto &name : PromptTemplateManager::instance().fimTemplatesNames()) { + fimPrompts.addOption(name); + } + for (const auto &name : PromptTemplateManager::instance().chatTemplatesNames()) { + chatPrompts.addOption(name); + } +} + class GeneralSettingsPage : public Core::IOptionsPage { public: diff --git a/settings/GeneralSettings.hpp b/settings/GeneralSettings.hpp index 709954e..804d609 100644 --- a/settings/GeneralSettings.hpp +++ b/settings/GeneralSettings.hpp @@ -21,6 +21,7 @@ #include +#include "providers/LLMProvider.hpp" #include "settings/SettingsUtils.hpp" namespace QodeAssist::Settings { @@ -47,17 +48,33 @@ public: Utils::SelectionAspect fimPrompts{this}; ButtonAspect resetToDefaults{this}; - Utils::StringAspect modelIndicator{this}; - Utils::StringAspect urlIndicator{this}; + Utils::SelectionAspect chatLlmProviders{this}; + Utils::StringAspect chatUrl{this}; + 
Utils::StringAspect chatEndPoint{this}; + + Utils::StringAspect chatModelName{this}; + ButtonAspect chatSelectModels{this}; + Utils::SelectionAspect chatPrompts{this}; + + Utils::StringAspect fimModelIndicator{this}; + Utils::StringAspect fimUrlIndicator{this}; + Utils::StringAspect chatModelIndicator{this}; + Utils::StringAspect chatUrlIndicator{this}; private: void setupConnections(); - void updateProviderSettings(); - void showModelSelectionDialog(); + void showModelSelectionDialog(Utils::StringAspect *modelNameObj, + Providers::LLMProvider *provider); void resetPageToDefaults(); void updateStatusIndicators(); void setIndicatorStatus(Utils::StringAspect &indicator, const QString &tooltip, bool isValid); + + void setCurrentFimProvider(const QString &name); + void setCurrentChatProvider(const QString &name); + + void loadProviders(); + void loadPrompts(); }; GeneralSettings &generalSettings(); diff --git a/templates/CodeLLamaTemplate.hpp b/templates/CodeLlamaFimTemplate.hpp similarity index 88% rename from templates/CodeLLamaTemplate.hpp rename to templates/CodeLlamaFimTemplate.hpp index 6115d28..48fedfc 100644 --- a/templates/CodeLLamaTemplate.hpp +++ b/templates/CodeLlamaFimTemplate.hpp @@ -23,10 +23,11 @@ namespace QodeAssist::Templates { -class CodeLLamaTemplate : public PromptTemplate +class CodeLlamaFimTemplate : public PromptTemplate { public: - QString name() const override { return "CodeLlama"; } + TemplateType type() const override { return TemplateType::Fim; } + QString name() const override { return "CodeLlama FIM"; } QString promptTemplate() const override { return "%1
 %2 %3 "; }
     QStringList stopWords() const override
     {
diff --git a/templates/CodeLlamaInstruct.hpp b/templates/CodeLlamaInstruct.hpp
new file mode 100644
index 0000000..96d1f22
--- /dev/null
+++ b/templates/CodeLlamaInstruct.hpp
@@ -0,0 +1,49 @@
+/* 
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include 
+#include "PromptTemplate.hpp"
+
+namespace QodeAssist::Templates {
+
+// Chat-type template: wraps the prompt text in CodeLlama's [INST] ... [/INST] markers.
+class CodeLlamaInstructTemplate : public PromptTemplate
+{
+public:
+    QString name() const override { return "CodeLlama Chat"; }
+    TemplateType type() const override { return TemplateType::Chat; }
+    QStringList stopWords() const override { return QStringList() << "[INST]" << "[/INST]"; }
+    QString promptTemplate() const override { return "[INST] %1 [/INST]"; }
+
+    // Appends the formatted context.prefix to request["messages"] as a "user" message.
+    void prepareRequest(QJsonObject &request, const ContextData &context) const override
+    {
+        QJsonObject userMessage;
+        userMessage["role"] = "user";
+        userMessage["content"] = promptTemplate().arg(context.prefix);
+
+        QJsonArray history = request["messages"].toArray();
+        history.append(userMessage);
+        request["messages"] = history;
+    }
+};
+
+} // namespace QodeAssist::Templates
diff --git a/templates/CustomTemplate.hpp b/templates/CustomTemplate.hpp
index 5ae2e24..5e1b913 100644
--- a/templates/CustomTemplate.hpp
+++ b/templates/CustomTemplate.hpp
@@ -32,7 +32,8 @@ namespace QodeAssist::Templates {
 class CustomTemplate : public PromptTemplate
 {
 public:
-    QString name() const override { return "Custom Template"; }
+    TemplateType type() const override { return TemplateType::Fim; }
+    QString name() const override { return "Custom FIM Template"; }
     QString promptTemplate() const override
     {
         return Settings::customPromptSettings().customJsonTemplate();
diff --git a/templates/DeepSeekCoderChatTemplate.hpp b/templates/DeepSeekCoderChatTemplate.hpp
new file mode 100644
index 0000000..775e614
--- /dev/null
+++ b/templates/DeepSeekCoderChatTemplate.hpp
@@ -0,0 +1,54 @@
+/* 
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include 
+#include "PromptTemplate.hpp"
+
+namespace QodeAssist::Templates {
+
+// Chat-type template using the "### Instruction:" / "### Response:" prompt layout.
+class DeepSeekCoderChatTemplate : public PromptTemplate
+{
+public:
+    TemplateType type() const override { return TemplateType::Chat; }
+    QString name() const override { return "DeepSeek Coder Chat"; }
+    QString promptTemplate() const override { return "### Instruction:\n%1\n### Response:\n"; }
+
+    QStringList stopWords() const override
+    {
+        QStringList stops;
+        stops << "### Instruction:" << "### Response:" << "\n\n### " << "<|EOT|>";
+        return stops;
+    }
+
+    // Appends the formatted context.prefix to request["messages"] as a "user" message.
+    void prepareRequest(QJsonObject &request, const ContextData &context) const override
+    {
+        QJsonObject userMessage;
+        userMessage["role"] = "user";
+        userMessage["content"] = promptTemplate().arg(context.prefix);
+        QJsonArray history = request["messages"].toArray();
+        history.append(userMessage);
+        request["messages"] = history;
+    }
+};
+
+} // namespace QodeAssist::Templates
diff --git a/templates/DeepSeekCoderV2.hpp b/templates/DeepSeekCoderV2.hpp
index 69886fb..4273a2e 100644
--- a/templates/DeepSeekCoderV2.hpp
+++ b/templates/DeepSeekCoderV2.hpp
@@ -26,7 +26,8 @@ namespace QodeAssist::Templates {
 class DeepSeekCoderV2Template : public PromptTemplate
 {
 public:
-    QString name() const override { return "DeepSeekCoderV2"; }
+    TemplateType type() const override { return TemplateType::Fim; }
+    QString name() const override { return "DeepSeekCoder FIM"; }
     QString promptTemplate() const override
     {
         return "%1<|fim▁begin|>%2<|fim▁hole|>%3<|fim▁end|>";
diff --git a/templates/PromptTemplate.hpp b/templates/PromptTemplate.hpp
index 5b134ee..0dee78e 100644
--- a/templates/PromptTemplate.hpp
+++ b/templates/PromptTemplate.hpp
@@ -27,10 +27,13 @@
 
 namespace QodeAssist::Templates {
 
+enum class TemplateType { Chat, Fim }; // value returned by PromptTemplate::type(): chat vs. fill-in-the-middle (FIM) template
+
 class PromptTemplate
 {
 public:
     virtual ~PromptTemplate() = default;
+    virtual TemplateType type() const = 0;
     virtual QString name() const = 0;
     virtual QString promptTemplate() const = 0;
     virtual QStringList stopWords() const = 0;
diff --git a/templates/StarCoder2Template.hpp b/templates/StarCoder2Template.hpp
index 8f89679..4d61c88 100644
--- a/templates/StarCoder2Template.hpp
+++ b/templates/StarCoder2Template.hpp
@@ -26,7 +26,8 @@ namespace QodeAssist::Templates {
 class StarCoder2Template : public PromptTemplate
 {
 public:
-    QString name() const override { return "StarCoder2"; }
+    TemplateType type() const override { return TemplateType::Fim; }
+    QString name() const override { return "StarCoder2 FIM"; }
     QString promptTemplate() const override { return "%1%2%3"; }
     QStringList stopWords() const override
     {