diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5fdeeba..5de4aa3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,26 +43,28 @@ add_qtc_plugin(QodeAssist
   LLMClientInterface.hpp LLMClientInterface.cpp
   templates/Templates.hpp
   templates/CodeLlamaFim.hpp
-  templates/StarCoder2Fim.hpp
-  templates/DeepSeekCoderFim.hpp
-  templates/CustomFimTemplate.hpp
-  templates/Qwen.hpp
   templates/Ollama.hpp
-  templates/BasicChat.hpp
+  templates/Claude.hpp
+  templates/OpenAI.hpp
+  templates/MistralAI.hpp
+  templates/StarCoder2Fim.hpp
+  # templates/DeepSeekCoderFim.hpp
+  # templates/CustomFimTemplate.hpp
+  templates/Qwen.hpp
+  templates/OpenAICompatible.hpp
   templates/Llama3.hpp
   templates/ChatML.hpp
   templates/Alpaca.hpp
   templates/Llama2.hpp
-  templates/Claude.hpp
-  templates/OpenAI.hpp
   templates/CodeLlamaQMLFim.hpp
   providers/Providers.hpp
   providers/OllamaProvider.hpp providers/OllamaProvider.cpp
+  providers/ClaudeProvider.hpp providers/ClaudeProvider.cpp
+  providers/OpenAIProvider.hpp providers/OpenAIProvider.cpp
+  providers/MistralAIProvider.hpp providers/MistralAIProvider.cpp
   providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp
   providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp
   providers/OpenRouterAIProvider.hpp providers/OpenRouterAIProvider.cpp
-  providers/ClaudeProvider.hpp providers/ClaudeProvider.cpp
-  providers/OpenAIProvider.hpp providers/OpenAIProvider.cpp
   QodeAssist.qrc
   LSPCompletion.hpp
   LLMSuggestion.hpp LLMSuggestion.cpp
diff --git a/ChatView/ClientInterface.cpp b/ChatView/ClientInterface.cpp
index 8de7186..c264fb0 100644
--- a/ChatView/ClientInterface.cpp
+++ b/ChatView/ClientInterface.cpp
@@ -93,51 +93,35 @@ void ClientInterface::sendMessage(
     }
 
     LLMCore::ContextData context;
-    context.prefix = message;
-    context.suffix = "";
 
-    QString systemPrompt;
-    if (chatAssistantSettings.useSystemPrompt())
-        systemPrompt = chatAssistantSettings.systemPrompt();
-
-    if (!linkedFiles.isEmpty()) {
-        systemPrompt = getSystemPromptWithLinkedFiles(systemPrompt, linkedFiles);
+    if (chatAssistantSettings.useSystemPrompt()) {
+        QString systemPrompt = chatAssistantSettings.systemPrompt();
+        if (!linkedFiles.isEmpty()) {
+            systemPrompt = getSystemPromptWithLinkedFiles(systemPrompt, linkedFiles);
+        }
+        context.systemPrompt = systemPrompt;
     }
 
-    QJsonObject providerRequest;
-    providerRequest["model"] = Settings::generalSettings().caModel();
-    providerRequest["stream"] = chatAssistantSettings.stream();
-    providerRequest["messages"] = m_chatModel->prepareMessagesForRequest(systemPrompt);
-
-    if (promptTemplate)
-        promptTemplate->prepareRequest(providerRequest, context);
-    else
-        qWarning("No prompt template found");
-
-    if (provider)
-        provider->prepareRequest(providerRequest, LLMCore::RequestType::Chat);
-    else
-        qWarning("No provider found");
+    QVector<LLMCore::Message> messages;
+    for (const auto &msg : m_chatModel->getChatHistory()) {
+        messages.append(
+            {msg.role == ChatModel::ChatRole::User ? "user" : "assistant", msg.content});
"user" : "assistant", msg.content}); + } + context.history = messages; LLMCore::LLMConfig config; config.requestType = LLMCore::RequestType::Chat; config.provider = provider; config.promptTemplate = promptTemplate; config.url = QString("%1%2").arg(Settings::generalSettings().caUrl(), provider->chatEndpoint()); - config.providerRequest = providerRequest; - config.multiLineCompletion = false; config.apiKey = provider->apiKey(); + config.providerRequest + = {{"model", Settings::generalSettings().caModel()}, + {"stream", chatAssistantSettings.stream()}}; - QJsonObject request; - request["id"] = QUuid::createUuid().toString(); - - auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type()); - if (!errors.isEmpty()) { - LOG_MESSAGE("Validate errors for chat request:"); - LOG_MESSAGES(errors); - return; - } + config.provider + ->prepareRequest(config.providerRequest, promptTemplate, context, LLMCore::RequestType::Chat); + QJsonObject request{{"id", QUuid::createUuid().toString()}}; m_requestHandler->sendLLMRequest(config, request); } diff --git a/LLMClientInterface.cpp b/LLMClientInterface.cpp index c69a17c..7602773 100644 --- a/LLMClientInterface.cpp +++ b/LLMClientInterface.cpp @@ -28,7 +28,6 @@ #include "CodeHandler.hpp" #include "context/DocumentContextReader.hpp" -#include "llmcore/MessageBuilder.hpp" #include "llmcore/PromptTemplateManager.hpp" #include "llmcore/ProvidersManager.hpp" #include "logger/Logger.hpp" @@ -159,7 +158,7 @@ bool QodeAssist::LLMClientInterface::isSpecifyCompletion(const QJsonObject &requ void LLMClientInterface::handleCompletion(const QJsonObject &request) { - const auto updatedContext = prepareContext(request); + auto updatedContext = prepareContext(request); auto &completeSettings = Settings::codeCompletionSettings(); auto &generalSettings = Settings::generalSettings(); @@ -196,7 +195,7 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request) config.promptTemplate = promptTemplate; config.url = QUrl(QString("%1%2").arg( url, - promptTemplate->type() == LLMCore::TemplateType::Fim ? provider->completionEndpoint() + promptTemplate->type() == LLMCore::TemplateType::FIM ? 
                                                              : provider->chatEndpoint()));
     config.apiKey = provider->apiKey();
@@ -211,29 +210,30 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
     QString systemPrompt;
     if (completeSettings.useSystemPrompt())
         systemPrompt.append(completeSettings.systemPrompt());
-    if (!updatedContext.fileContext.isEmpty())
-        systemPrompt.append(updatedContext.fileContext);
+    if (updatedContext.fileContext.has_value())
+        systemPrompt.append(updatedContext.fileContext.value());
 
-    QString userMessage;
-    if (completeSettings.useUserMessageTemplateForCC()
-        && promptTemplate->type() == LLMCore::TemplateType::Chat) {
-        userMessage
-            = completeSettings.processMessageToFIM(updatedContext.prefix, updatedContext.suffix);
-    } else {
-        userMessage = updatedContext.prefix;
+    updatedContext.systemPrompt = systemPrompt;
+
+    if (promptTemplate->type() == LLMCore::TemplateType::Chat) {
+        QString userMessage;
+        if (completeSettings.useUserMessageTemplateForCC()) {
+            userMessage = completeSettings.processMessageToFIM(
+                updatedContext.prefix.value_or(""), updatedContext.suffix.value_or(""));
+        } else {
+            userMessage = updatedContext.prefix.value_or("") + updatedContext.suffix.value_or("");
+        }
+
+        QVector<LLMCore::Message> messages;
+        messages.append({"user", userMessage});
+        updatedContext.history = messages;
     }
 
-    auto message = LLMCore::MessageBuilder()
-                       .addSystemMessage(systemPrompt)
-                       .addUserMessage(userMessage)
-                       .addSuffix(updatedContext.suffix)
-                       .addTokenizer(promptTemplate);
-
-    message.saveTo(
+    config.provider->prepareRequest(
         config.providerRequest,
-        providerName == "Ollama" ? LLMCore::ProvidersApi::Ollama : LLMCore::ProvidersApi::OpenAI);
-
-    config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::CodeCompletion);
+        promptTemplate,
+        updatedContext,
+        LLMCore::RequestType::CodeCompletion);
 
     auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
     if (!errors.isEmpty()) {
diff --git a/context/DocumentContextReader.cpp b/context/DocumentContextReader.cpp
index 81c6aa9..feec2ab 100644
--- a/context/DocumentContextReader.cpp
+++ b/context/DocumentContextReader.cpp
@@ -216,7 +216,7 @@ LLMCore::ContextData DocumentContextReader::prepareContext(int lineNumber, int c
         fileContext.append("\n ").append(
             ChangesManager::instance().getRecentChangesContext(m_textDocument));
 
-    return {contextBefore, contextAfter, fileContext};
+    return {.prefix = contextBefore, .suffix = contextAfter, .fileContext = fileContext};
 }
 
 QString DocumentContextReader::getContextBefore(int lineNumber, int cursorPosition) const
diff --git a/llmcore/CMakeLists.txt b/llmcore/CMakeLists.txt
index 93d7bcf..4a8e0d9 100644
--- a/llmcore/CMakeLists.txt
+++ b/llmcore/CMakeLists.txt
@@ -10,7 +10,6 @@ add_library(LLMCore STATIC
     OllamaMessage.hpp OllamaMessage.cpp
     OpenAIMessage.hpp OpenAIMessage.cpp
     ValidationUtils.hpp ValidationUtils.cpp
-    MessageBuilder.hpp MessageBuilder.cpp
 )
 
 target_link_libraries(LLMCore
diff --git a/llmcore/ContextData.hpp b/llmcore/ContextData.hpp
index 4d0ea92..1515540 100644
--- a/llmcore/ContextData.hpp
+++ b/llmcore/ContextData.hpp
@@ -20,14 +20,23 @@
 #pragma once
 
 #include <QString>
+#include <optional>
 
 namespace QodeAssist::LLMCore {
 
+struct Message
+{
+    QString role;
+    QString content;
+};
+
 struct ContextData
 {
-    QString prefix;
-    QString suffix;
-    QString fileContext;
+    std::optional<QString> systemPrompt;
+    std::optional<QString> prefix;
+    std::optional<QString> suffix;
+    std::optional<QString> fileContext;
+    std::optional<QVector<Message>> history;
 };
 
 } // namespace QodeAssist::LLMCore
diff --git a/llmcore/MessageBuilder.cpp b/llmcore/MessageBuilder.cpp
deleted file mode 100644
index cf25fde..0000000
--- a/llmcore/MessageBuilder.cpp
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Copyright (C) 2024 Petr Mironychev
- *
- * This file is part of QodeAssist.
- *
- * QodeAssist is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * QodeAssist is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
- */
-
-#include "MessageBuilder.hpp"
-
-QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addSystemMessage(
-    const QString &content)
-{
-    m_systemMessage = content;
-    return *this;
-}
-
-QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addUserMessage(
-    const QString &content)
-{
-    m_messages.append({MessageRole::User, content});
-    return *this;
-}
-
-QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addSuffix(
-    const QString &content)
-{
-    m_suffix = content;
-    return *this;
-}
-
-QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addTokenizer(
-    PromptTemplate *promptTemplate)
-{
-    m_promptTemplate = promptTemplate;
-    return *this;
-}
-
-QString QodeAssist::LLMCore::MessageBuilder::roleToString(MessageRole role) const
-{
-    switch (role) {
-    case MessageRole::System:
-        return ROLE_SYSTEM;
-    case MessageRole::User:
-        return ROLE_USER;
-    case MessageRole::Assistant:
-        return ROLE_ASSISTANT;
-    default:
-        return ROLE_USER;
-    }
-}
-
-void QodeAssist::LLMCore::MessageBuilder::saveTo(QJsonObject &request, ProvidersApi api)
-{
-    if (!m_promptTemplate) {
-        return;
-    }
-
-    ContextData context{
-        m_messages.isEmpty() ? QString() : m_messages.last().content, m_suffix, m_systemMessage};
-
-    if (api == ProvidersApi::Ollama) {
-        if (m_promptTemplate->type() == TemplateType::Fim) {
-            request["system"] = m_systemMessage;
-            m_promptTemplate->prepareRequest(request, context);
-        } else {
-            QJsonArray messages;
-
-            messages.append(QJsonObject{{"role", "system"}, {"content", m_systemMessage}});
-            messages.append(QJsonObject{{"role", "user"}, {"content", m_messages.last().content}});
-            request["messages"] = messages;
-            m_promptTemplate->prepareRequest(request, context);
-        }
-    } else if (api == ProvidersApi::OpenAI) {
-        QJsonArray messages;
-
-        messages.append(QJsonObject{{"role", "system"}, {"content", m_systemMessage}});
-        messages.append(QJsonObject{{"role", "user"}, {"content", m_messages.last().content}});
-        request["messages"] = messages;
-        m_promptTemplate->prepareRequest(request, context);
-    }
-}
diff --git a/llmcore/MessageBuilder.hpp b/llmcore/MessageBuilder.hpp
deleted file mode 100644
index f37bf6c..0000000
--- a/llmcore/MessageBuilder.hpp
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Copyright (C) 2024 Petr Mironychev
- *
- * This file is part of QodeAssist.
- *
- * QodeAssist is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * QodeAssist is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
- */
-
-#pragma once
-
-#include <QJsonArray>
-#include <QJsonObject>
-#include <QString>
-#include <QVector>
-
-#include "PromptTemplate.hpp"
-
-namespace QodeAssist::LLMCore {
-
-enum class MessageRole { System, User, Assistant };
-
-enum class OllamaFormat { Messages, Completions };
-
-enum class ProvidersApi { Ollama, OpenAI, Claude };
-
-static const QString ROLE_SYSTEM = "system";
-static const QString ROLE_USER = "user";
-static const QString ROLE_ASSISTANT = "assistant";
-
-struct Message
-{
-    MessageRole role;
-    QString content;
-};
-
-class MessageBuilder
-{
-public:
-    MessageBuilder &addSystemMessage(const QString &content);
-
-    MessageBuilder &addUserMessage(const QString &content);
-
-    MessageBuilder &addSuffix(const QString &content);
-
-    MessageBuilder &addTokenizer(PromptTemplate *promptTemplate);
-
-    QString roleToString(MessageRole role) const;
-
-    void saveTo(QJsonObject &request, ProvidersApi api);
-
-private:
-    QString m_systemMessage;
-    QString m_suffix;
-    QVector<Message> m_messages;
-    PromptTemplate *m_promptTemplate;
-};
-} // namespace QodeAssist::LLMCore
diff --git a/llmcore/PromptTemplate.hpp b/llmcore/PromptTemplate.hpp
index 5d7ff2b..95cbe56 100644
--- a/llmcore/PromptTemplate.hpp
+++ b/llmcore/PromptTemplate.hpp
@@ -27,7 +27,7 @@
 
 namespace QodeAssist::LLMCore {
 
-enum class TemplateType { Chat, Fim };
+enum class TemplateType { Chat, FIM };
 
 class PromptTemplate
 {
@@ -35,7 +35,7 @@ public:
     virtual ~PromptTemplate() = default;
     virtual TemplateType type() const = 0;
     virtual QString name() const = 0;
-    virtual QString promptTemplate() const = 0;
+    // virtual QString promptTemplate() const = 0;
     virtual QStringList stopWords() const = 0;
     virtual void prepareRequest(QJsonObject &request, const ContextData &context) const = 0;
     virtual QString description() const = 0;
diff --git a/llmcore/Provider.hpp b/llmcore/Provider.hpp
index 9ff1451..a33d76e 100644
--- a/llmcore/Provider.hpp
+++ b/llmcore/Provider.hpp
@@ -23,6 +23,7 @@
 #include <QJsonObject>
 #include <QNetworkReply>
 
+#include "ContextData.hpp"
 #include "PromptTemplate.hpp"
 #include "RequestType.hpp"
 
@@ -42,7 +43,12 @@ public:
     virtual QString chatEndpoint() const = 0;
     virtual bool supportsModelListing() const = 0;
 
-    virtual void prepareRequest(QJsonObject &request, RequestType type) = 0;
+    virtual void prepareRequest(
+        QJsonObject &request,
+        LLMCore::PromptTemplate *prompt,
+        LLMCore::ContextData context,
+        LLMCore::RequestType type)
+        = 0;
     virtual bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) = 0;
     virtual QList<QString> getInstalledModels(const QString &url) = 0;
     virtual QList<QString> validateRequest(const QJsonObject &request, TemplateType type) = 0;
diff --git a/llmcore/RequestHandler.cpp b/llmcore/RequestHandler.cpp
index b57d349..81dc8f4 100644
--- a/llmcore/RequestHandler.cpp
+++ b/llmcore/RequestHandler.cpp
@@ -58,7 +58,10 @@ void RequestHandler::sendLLMRequest(const LLMConfig &config, const QJsonObject &
         reply->deleteLater();
         m_activeRequests.remove(requestId);
         if (reply->error() != QNetworkReply::NoError) {
-            LOG_MESSAGE(QString("Error in QodeAssist request: %1").arg(reply->errorString()));
+            LOG_MESSAGE(QString("Error details: %1\nStatus code: %2\nResponse: %3")
+                            .arg(reply->errorString())
+                            .arg(reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt())
+                            .arg(QString(reply->readAll())));
             emit requestFinished(requestId, false, reply->errorString());
         } else {
             LOG_MESSAGE("Request finished successfully");
diff --git a/llmcore/RequestType.hpp b/llmcore/RequestType.hpp
index fe9de64..8647369 100644
--- a/llmcore/RequestType.hpp
+++ b/llmcore/RequestType.hpp
@@ -21,5 +21,5 @@
 
 namespace QodeAssist::LLMCore {
 
-enum RequestType { CodeCompletion, Chat };
+enum RequestType { CodeCompletion, Chat, Embedding };
 }
diff --git a/providers/ClaudeProvider.cpp b/providers/ClaudeProvider.cpp
index 426024c..e41e67e 100644
--- a/providers/ClaudeProvider.cpp
+++ b/providers/ClaudeProvider.cpp
@@ -30,13 +30,10 @@
 #include "logger/Logger.hpp"
 #include "settings/ChatAssistantSettings.hpp"
 #include "settings/CodeCompletionSettings.hpp"
-#include "settings/GeneralSettings.hpp"
 #include "settings/ProviderSettings.hpp"
 
 namespace QodeAssist::Providers {
 
-ClaudeProvider::ClaudeProvider() {}
-
 QString ClaudeProvider::name() const
 {
     return "Claude";
@@ -62,31 +59,17 @@ bool ClaudeProvider::supportsModelListing() const
     return true;
 }
 
-void ClaudeProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
+void ClaudeProvider::prepareRequest(
+    QJsonObject &request,
+    LLMCore::PromptTemplate *prompt,
+    LLMCore::ContextData context,
+    LLMCore::RequestType type)
 {
-    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
-        QJsonArray messages;
-        if (req.contains("messages")) {
-            QJsonArray origMessages = req["messages"].toArray();
-            for (const auto &msg : origMessages) {
-                QJsonObject message = msg.toObject();
-                if (message["role"].toString() == "system") {
-                    req["system"] = message["content"];
-                } else {
-                    messages.append(message);
-                }
-            }
-        } else {
-            if (req.contains("system")) {
-                req["system"] = req["system"].toString();
-            }
-            if (req.contains("prompt")) {
-                messages.append(
-                    QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
-            }
-        }
-        return messages;
-    };
+    // if (!isSupportedTemplate(prompt->name())) {
+    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
+    // }
+
+    prompt->prepareRequest(request, context);
 
     auto applyModelParams = [&request](const auto &settings) {
         request["max_tokens"] = settings.maxTokens();
@@ -98,11 +81,6 @@ void ClaudeProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType t
         request["stream"] = true;
     };
 
-    QJsonArray messages = prepareMessages(request);
-    if (!messages.isEmpty()) {
-        request["messages"] = std::move(messages);
-    }
-
     if (type == LLMCore::RequestType::CodeCompletion) {
         applyModelParams(Settings::codeCompletionSettings());
     } else {
diff --git a/providers/ClaudeProvider.hpp b/providers/ClaudeProvider.hpp
index 9d55cc3..55c68b1 100644
--- a/providers/ClaudeProvider.hpp
+++ b/providers/ClaudeProvider.hpp
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
 class ClaudeProvider : public LLMCore::Provider
 {
 public:
-    ClaudeProvider();
-
     QString name() const override;
     QString url() const override;
     QString completionEndpoint() const override;
     QString chatEndpoint() const override;
     bool supportsModelListing() const override;
-    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
+    void prepareRequest(
+        QJsonObject &request,
+        LLMCore::PromptTemplate *prompt,
+        LLMCore::ContextData context,
+        LLMCore::RequestType type) override;
     bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
     QList<QString> getInstalledModels(const QString &url) override;
     QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
diff --git a/providers/LMStudioProvider.cpp b/providers/LMStudioProvider.cpp
index 168cbd3..af6ab9e 100644
--- a/providers/LMStudioProvider.cpp
+++ b/providers/LMStudioProvider.cpp
@@ -33,8 +33,6 @@
 
 namespace QodeAssist::Providers {
 
-LMStudioProvider::LMStudioProvider() {}
-
 QString LMStudioProvider::name() const
 {
     return "LM Studio";
@@ -60,47 +58,6 @@ bool LMStudioProvider::supportsModelListing() const
     return true;
 }
 
-void LMStudioProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
-{
-    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
-        QJsonArray messages;
-        if (req.contains("system")) {
-            messages.append(
-                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
-        }
-        if (req.contains("prompt")) {
-            messages.append(
-                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
-        }
-        return messages;
-    };
-
-    auto applyModelParams = [&request](const auto &settings) {
-        request["max_tokens"] = settings.maxTokens();
-        request["temperature"] = settings.temperature();
-
-        if (settings.useTopP())
-            request["top_p"] = settings.topP();
-        if (settings.useTopK())
-            request["top_k"] = settings.topK();
-        if (settings.useFrequencyPenalty())
-            request["frequency_penalty"] = settings.frequencyPenalty();
-        if (settings.usePresencePenalty())
-            request["presence_penalty"] = settings.presencePenalty();
-    };
-
-    QJsonArray messages = prepareMessages(request);
-    if (!messages.isEmpty()) {
-        request["messages"] = std::move(messages);
-    }
-
-    if (type == LLMCore::RequestType::CodeCompletion) {
-        applyModelParams(Settings::codeCompletionSettings());
-    } else {
-        applyModelParams(Settings::chatAssistantSettings());
-    }
-}
-
 bool LMStudioProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
 {
     QByteArray data = reply->readAll();
@@ -211,4 +168,37 @@ void LMStudioProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) co
     networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
 }
 
+void QodeAssist::Providers::LMStudioProvider::prepareRequest(
+    QJsonObject &request,
+    LLMCore::PromptTemplate *prompt,
+    LLMCore::ContextData context,
+    LLMCore::RequestType type)
+{
+    // if (!isSupportedTemplate(prompt->name())) {
+    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
+    // }
+
+    prompt->prepareRequest(request, context);
+
+    auto applyModelParams = [&request](const auto &settings) {
+        request["max_tokens"] = settings.maxTokens();
+        request["temperature"] = settings.temperature();
+
+        if (settings.useTopP())
+            request["top_p"] = settings.topP();
+        if (settings.useTopK())
+            request["top_k"] = settings.topK();
+        if (settings.useFrequencyPenalty())
+            request["frequency_penalty"] = settings.frequencyPenalty();
+        if (settings.usePresencePenalty())
+            request["presence_penalty"] = settings.presencePenalty();
+    };
+
+    if (type == LLMCore::RequestType::CodeCompletion) {
+        applyModelParams(Settings::codeCompletionSettings());
+    } else {
+        applyModelParams(Settings::chatAssistantSettings());
+    }
+}
+
 } // namespace QodeAssist::Providers
diff --git a/providers/LMStudioProvider.hpp b/providers/LMStudioProvider.hpp
index 3c28cdb..1014a00 100644
--- a/providers/LMStudioProvider.hpp
+++ b/providers/LMStudioProvider.hpp
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
 class LMStudioProvider : public LLMCore::Provider
 {
 public:
-    LMStudioProvider();
-
     QString name() const override;
     QString url() const override;
     QString completionEndpoint() const override;
     QString chatEndpoint() const override;
     bool supportsModelListing() const override;
-    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
+    void prepareRequest(
+        QJsonObject &request,
+        LLMCore::PromptTemplate *prompt,
+        LLMCore::ContextData context,
+        LLMCore::RequestType type) override;
     bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
     QList<QString> getInstalledModels(const QString &url) override;
     QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
diff --git a/providers/MistralAIProvider.cpp b/providers/MistralAIProvider.cpp
new file mode 100644
index 0000000..201e757
--- /dev/null
+++ b/providers/MistralAIProvider.cpp
@@ -0,0 +1,215 @@
+#include "MistralAIProvider.hpp"
+
+#include "settings/ChatAssistantSettings.hpp"
+#include "settings/CodeCompletionSettings.hpp"
+#include "settings/ProviderSettings.hpp"
+
+#include <QEventLoop>
+#include <QJsonArray>
+#include <QJsonDocument>
+#include <QNetworkAccessManager>
+#include <QNetworkReply>
+
+#include "llmcore/OpenAIMessage.hpp"
+#include "llmcore/ValidationUtils.hpp"
+#include "logger/Logger.hpp"
+
+namespace QodeAssist::Providers {
+
+QString MistralAIProvider::name() const
+{
+    return "Mistral AI";
+}
+
+QString MistralAIProvider::url() const
+{
+    return "https://api.mistral.ai";
+}
+
+QString MistralAIProvider::completionEndpoint() const
+{
+    return "/v1/fim/completions";
+}
+
+QString MistralAIProvider::chatEndpoint() const
+{
+    return "/v1/chat/completions";
+}
+
+bool MistralAIProvider::supportsModelListing() const
+{
+    return true;
+}
+
+bool MistralAIProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
+{
+    QByteArray data = reply->readAll();
+    if (data.isEmpty()) {
+        return false;
+    }
+
+    bool isDone = false;
+    QByteArrayList lines = data.split('\n');
+
+    for (const QByteArray &line : lines) {
+        if (line.trimmed().isEmpty()) {
+            continue;
+        }
+
+        if (line == "data: [DONE]") {
+            isDone = true;
+            continue;
+        }
+
+        QByteArray jsonData = line;
+        if (line.startsWith("data: ")) {
+            jsonData = line.mid(6);
+        }
+
+        QJsonParseError error;
+        QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error);
+
+        if (doc.isNull()) {
+            continue;
+        }
+
+        auto message = LLMCore::OpenAIMessage::fromJson(doc.object());
+        if (message.hasError()) {
+            LOG_MESSAGE("Error in OpenAI response: " + message.error);
+            continue;
+        }
+
+        QString content = message.getContent();
+        if (!content.isEmpty()) {
+            accumulatedResponse += content;
+        }
+
+        if (message.isDone()) {
+            isDone = true;
+        }
+    }
+
+    return isDone;
+}
+
+QList<QString> MistralAIProvider::getInstalledModels(const QString &url)
+{
+    QList<QString> models;
+    QNetworkAccessManager manager;
+    QNetworkRequest request(QString("%1/v1/models").arg(url));
+
+    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
+    if (!apiKey().isEmpty()) {
+        request.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8());
+    }
+
+    QNetworkReply *reply = manager.get(request);
+    QEventLoop loop;
+    QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
+    loop.exec();
+
+    if (reply->error() == QNetworkReply::NoError) {
+        QByteArray responseData = reply->readAll();
+        QJsonDocument jsonResponse = QJsonDocument::fromJson(responseData);
+        QJsonObject jsonObject = jsonResponse.object();
+
+        if (jsonObject.contains("data") && jsonObject["object"].toString() == "list") {
+            QJsonArray modelArray = jsonObject["data"].toArray();
+            for (const QJsonValue &value : modelArray) {
+                QJsonObject modelObject = value.toObject();
+                if (modelObject.contains("id")) {
+                    QString modelId = modelObject["id"].toString();
+                    models.append(modelId);
+                }
+            }
+        }
+    } else {
+        LOG_MESSAGE(QString("Error fetching Mistral AI models: %1").arg(reply->errorString()));
+    }
+
+    reply->deleteLater();
+    return models;
+}
+
+QList<QString> MistralAIProvider::validateRequest(
+    const QJsonObject &request, LLMCore::TemplateType type)
+{
+    const auto fimReq = QJsonObject{
+        {"model", {}},
+        {"max_tokens", {}},
+        {"stream", {}},
+        {"temperature", {}},
+        {"prompt", {}},
+        {"suffix", {}}};
+
+    const auto templateReq = QJsonObject{
+        {"model", {}},
+        {"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
+        {"temperature", {}},
+        {"max_tokens", {}},
+        {"top_p", {}},
+        {"frequency_penalty", {}},
+        {"presence_penalty", {}},
+        {"stop", QJsonArray{}},
+        {"stream", {}}};
+
+    return LLMCore::ValidationUtils::validateRequestFields(
+        request, type == LLMCore::TemplateType::FIM ? fimReq : templateReq);
+}
+
+QString MistralAIProvider::apiKey() const
+{
+    return Settings::providerSettings().mistralAiApiKey();
+}
+
+void MistralAIProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const
+{
+    networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
+
+    if (!apiKey().isEmpty()) {
+        networkRequest.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8());
+    }
+}
+
+void MistralAIProvider::prepareRequest(
+    QJsonObject &request,
+    LLMCore::PromptTemplate *prompt,
+    LLMCore::ContextData context,
+    LLMCore::RequestType type)
+{
+    // if (!isSupportedTemplate(prompt->name())) {
+    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
+    // }
+
+    prompt->prepareRequest(request, context);
+
+    if (type == LLMCore::RequestType::Chat) {
+        auto &settings = Settings::chatAssistantSettings();
+
+        request["max_tokens"] = settings.maxTokens();
+        request["temperature"] = settings.temperature();
+
+        if (settings.useTopP())
+            request["top_p"] = settings.topP();
+
+        // request["random_seed"] = "";
+
+        if (settings.useFrequencyPenalty())
+            request["frequency_penalty"] = settings.frequencyPenalty();
+        if (settings.usePresencePenalty())
+            request["presence_penalty"] = settings.presencePenalty();
+
+    } else {
+        auto &settings = Settings::codeCompletionSettings();
+
+        request["max_tokens"] = settings.maxTokens();
+        request["temperature"] = settings.temperature();
+
+        if (settings.useTopP())
+            request["top_p"] = settings.topP();
+
+        // request["random_seed"] = "";
+    }
+}
+
+} // namespace QodeAssist::Providers
diff --git a/providers/MistralAIProvider.hpp b/providers/MistralAIProvider.hpp
new file mode 100644
index 0000000..5ce3344
--- /dev/null
+++ b/providers/MistralAIProvider.hpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include "llmcore/Provider.hpp"
+
+namespace QodeAssist::Providers {
+
+class MistralAIProvider : public LLMCore::Provider
+{
+public:
+    QString name() const override;
+    QString url() const override;
+    QString completionEndpoint() const override;
+    QString chatEndpoint() const override;
+    bool supportsModelListing() const override;
+    void prepareRequest(
+        QJsonObject &request,
+        LLMCore::PromptTemplate *prompt,
+        LLMCore::ContextData context,
+        LLMCore::RequestType type) override;
+    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
+    QList<QString> getInstalledModels(const QString &url) override;
+    QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
+    QString apiKey() const override;
+    void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
+};
+
+} // namespace QodeAssist::Providers
diff --git a/providers/OllamaProvider.cpp b/providers/OllamaProvider.cpp
index 2c93013..d0cbf7a 100644
--- a/providers/OllamaProvider.cpp
+++ b/providers/OllamaProvider.cpp
@@ -33,8 +33,6 @@
 
 namespace QodeAssist::Providers {
 
-OllamaProvider::OllamaProvider() {}
-
 QString OllamaProvider::name() const
 {
     return "Ollama";
@@ -60,8 +58,18 @@ bool OllamaProvider::supportsModelListing() const
     return true;
 }
 
-void OllamaProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
+void OllamaProvider::prepareRequest(
+    QJsonObject &request,
+    LLMCore::PromptTemplate *prompt,
+    LLMCore::ContextData context,
+    LLMCore::RequestType type)
 {
+    // if (!isSupportedTemplate(prompt->name())) {
+    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
+    // }
+
+    prompt->prepareRequest(request, context);
+
     auto applySettings = [&request](const auto &settings) {
         QJsonObject options;
         options["num_predict"] = settings.maxTokens();
@@ -192,7 +200,7 @@ QList<QString> OllamaProvider::validateRequest(const QJsonObject &request, LLMCo
         {"presence_penalty", {}}}}};
 
     return LLMCore::ValidationUtils::validateRequestFields(
-        request, type == LLMCore::TemplateType::Fim ? fimReq : messageReq);
+        request, type == LLMCore::TemplateType::FIM ? fimReq : messageReq);
 }
 
 QString OllamaProvider::apiKey() const
@@ -203,6 +211,6 @@ QString OllamaProvider::apiKey() const
 void OllamaProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const
 {
     networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
-};
+}
 
 } // namespace QodeAssist::Providers
diff --git a/providers/OllamaProvider.hpp b/providers/OllamaProvider.hpp
index f963fb0..38cea88 100644
--- a/providers/OllamaProvider.hpp
+++ b/providers/OllamaProvider.hpp
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
 class OllamaProvider : public LLMCore::Provider
 {
 public:
-    OllamaProvider();
-
     QString name() const override;
     QString url() const override;
     QString completionEndpoint() const override;
     QString chatEndpoint() const override;
     bool supportsModelListing() const override;
-    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
+    void prepareRequest(
+        QJsonObject &request,
+        LLMCore::PromptTemplate *prompt,
+        LLMCore::ContextData context,
+        LLMCore::RequestType type) override;
     bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
     QList<QString> getInstalledModels(const QString &url) override;
     QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
diff --git a/providers/OpenAICompatProvider.cpp b/providers/OpenAICompatProvider.cpp
index f515fd4..547cec7 100644
--- a/providers/OpenAICompatProvider.cpp
+++ b/providers/OpenAICompatProvider.cpp
@@ -34,8 +34,6 @@
 
 namespace QodeAssist::Providers {
 
-OpenAICompatProvider::OpenAICompatProvider() {}
-
 QString OpenAICompatProvider::name() const
 {
     return "OpenAI Compatible";
@@ -61,20 +59,17 @@ bool OpenAICompatProvider::supportsModelListing() const
     return false;
 }
 
-void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
+void OpenAICompatProvider::prepareRequest(
+    QJsonObject &request,
+    LLMCore::PromptTemplate *prompt,
+    LLMCore::ContextData context,
+    LLMCore::RequestType type)
 {
-    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
-        QJsonArray messages;
-        if (req.contains("system")) {
-            messages.append(
-                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
-        }
-        if (req.contains("prompt")) {
-            messages.append(
-                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
-        }
-        return messages;
-    };
+    // if (!isSupportedTemplate(prompt->name())) {
+    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
+    // }
+
+    prompt->prepareRequest(request, context);
 
     auto applyModelParams = [&request](const auto &settings) {
         request["max_tokens"] = settings.maxTokens();
@@ -90,11 +85,6 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::Request
         request["presence_penalty"] = settings.presencePenalty();
     };
 
-    QJsonArray messages = prepareMessages(request);
-    if (!messages.isEmpty()) {
-        request["messages"] = std::move(messages);
-    }
-
     if (type == LLMCore::RequestType::CodeCompletion) {
         applyModelParams(Settings::codeCompletionSettings());
     } else {
diff --git a/providers/OpenAICompatProvider.hpp b/providers/OpenAICompatProvider.hpp
index c6abebc..cfcf890 100644
--- a/providers/OpenAICompatProvider.hpp
+++ b/providers/OpenAICompatProvider.hpp
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
 class OpenAICompatProvider : public LLMCore::Provider
 {
 public:
-    OpenAICompatProvider();
-
     QString name() const override;
     QString url() const override;
     QString completionEndpoint() const override;
     QString chatEndpoint() const override;
     bool supportsModelListing() const override;
-    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
+    void prepareRequest(
+        QJsonObject &request,
+        LLMCore::PromptTemplate *prompt,
+        LLMCore::ContextData context,
+        LLMCore::RequestType type) override;
     bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
     QList<QString> getInstalledModels(const QString &url) override;
     QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
diff --git a/providers/OpenAIProvider.cpp b/providers/OpenAIProvider.cpp
index c962ae8..81bd7ed 100644
--- a/providers/OpenAIProvider.cpp
+++ b/providers/OpenAIProvider.cpp
@@ -35,8 +35,6 @@
 
 namespace QodeAssist::Providers {
 
-OpenAIProvider::OpenAIProvider() {}
-
 QString OpenAIProvider::name() const
 {
     return "OpenAI";
@@ -62,20 +60,17 @@ bool OpenAIProvider::supportsModelListing() const
     return true;
 }
 
-void OpenAIProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
+void OpenAIProvider::prepareRequest(
+    QJsonObject &request,
+    LLMCore::PromptTemplate *prompt,
+    LLMCore::ContextData context,
+    LLMCore::RequestType type)
 {
-    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
-        QJsonArray messages;
-        if (req.contains("system")) {
-            messages.append(
-                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
-        }
-        if (req.contains("prompt")) {
-            messages.append(
-                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
-        }
-        return messages;
-    };
+    // if (!isSupportedTemplate(prompt->name())) {
+    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
+    // }
+
+    prompt->prepareRequest(request, context);
 
     auto applyModelParams = [&request](const auto &settings) {
         request["max_tokens"] = settings.maxTokens();
@@ -91,11 +86,6 @@ void OpenAIProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType t
         request["presence_penalty"] = settings.presencePenalty();
     };
 
-    QJsonArray messages = prepareMessages(request);
-    if (!messages.isEmpty()) {
-        request["messages"] = std::move(messages);
-    }
-
     if (type == LLMCore::RequestType::CodeCompletion) {
         applyModelParams(Settings::codeCompletionSettings());
     } else {
diff --git a/providers/OpenAIProvider.hpp b/providers/OpenAIProvider.hpp
index d47b1e1..5706a83 100644
--- a/providers/OpenAIProvider.hpp
+++ b/providers/OpenAIProvider.hpp
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
 class OpenAIProvider : public LLMCore::Provider
 {
 public:
-    OpenAIProvider();
-
     QString name() const override;
     QString url() const override;
     QString completionEndpoint() const override;
    QString chatEndpoint() const override;
     bool supportsModelListing() const override;
-    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
+    void prepareRequest(
+        QJsonObject &request,
+        LLMCore::PromptTemplate *prompt,
+        LLMCore::ContextData context,
+        LLMCore::RequestType type) override;
     bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
     QList<QString> getInstalledModels(const QString &url) override;
     QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
diff --git a/providers/OpenRouterAIProvider.cpp b/providers/OpenRouterAIProvider.cpp
index faa2a88..a6381fd 100644
--- a/providers/OpenRouterAIProvider.cpp
+++ b/providers/OpenRouterAIProvider.cpp
@@ -33,8 +33,6 @@
 
 namespace QodeAssist::Providers {
 
-OpenRouterProvider::OpenRouterProvider() {}
-
 QString OpenRouterProvider::name() const
{ return "OpenRouter"; @@ -45,47 +43,6 @@ QString OpenRouterProvider::url() const return "https://openrouter.ai/api"; } -void OpenRouterProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type) -{ - auto prepareMessages = [](QJsonObject &req) -> QJsonArray { - QJsonArray messages; - if (req.contains("system")) { - messages.append( - QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}}); - } - if (req.contains("prompt")) { - messages.append( - QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}}); - } - return messages; - }; - - auto applyModelParams = [&request](const auto &settings) { - request["max_tokens"] = settings.maxTokens(); - request["temperature"] = settings.temperature(); - - if (settings.useTopP()) - request["top_p"] = settings.topP(); - if (settings.useTopK()) - request["top_k"] = settings.topK(); - if (settings.useFrequencyPenalty()) - request["frequency_penalty"] = settings.frequencyPenalty(); - if (settings.usePresencePenalty()) - request["presence_penalty"] = settings.presencePenalty(); - }; - - QJsonArray messages = prepareMessages(request); - if (!messages.isEmpty()) { - request["messages"] = std::move(messages); - } - - if (type == LLMCore::RequestType::CodeCompletion) { - applyModelParams(Settings::codeCompletionSettings()); - } else { - applyModelParams(Settings::chatAssistantSettings()); - } -} - bool OpenRouterProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse) { QByteArray data = reply->readAll(); diff --git a/providers/OpenRouterAIProvider.hpp b/providers/OpenRouterAIProvider.hpp index a49e9b7..f7a2c08 100644 --- a/providers/OpenRouterAIProvider.hpp +++ b/providers/OpenRouterAIProvider.hpp @@ -27,11 +27,8 @@ namespace QodeAssist::Providers { class OpenRouterProvider : public OpenAICompatProvider { public: - OpenRouterProvider(); - QString name() const override; QString url() const override; - void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override; bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; QString apiKey() const override; }; diff --git a/providers/Providers.hpp b/providers/Providers.hpp index 811cb1e..81fdd28 100644 --- a/providers/Providers.hpp +++ b/providers/Providers.hpp @@ -22,6 +22,7 @@ #include "llmcore/ProvidersManager.hpp" #include "providers/ClaudeProvider.hpp" #include "providers/LMStudioProvider.hpp" +#include "providers/MistralAIProvider.hpp" #include "providers/OllamaProvider.hpp" #include "providers/OpenAICompatProvider.hpp" #include "providers/OpenAIProvider.hpp" @@ -33,11 +34,12 @@ inline void registerProviders() { auto &providerManager = LLMCore::ProvidersManager::instance(); providerManager.registerProvider(); - providerManager.registerProvider(); - providerManager.registerProvider(); - providerManager.registerProvider(); providerManager.registerProvider(); providerManager.registerProvider(); + providerManager.registerProvider(); + providerManager.registerProvider(); + providerManager.registerProvider(); + providerManager.registerProvider(); } } // namespace QodeAssist::Providers diff --git a/settings/ProviderSettings.cpp b/settings/ProviderSettings.cpp index b217787..db5422a 100644 --- a/settings/ProviderSettings.cpp +++ b/settings/ProviderSettings.cpp @@ -69,6 +69,7 @@ ProviderSettings::ProviderSettings() claudeApiKey.setDefaultValue(""); claudeApiKey.setAutoApply(true); + // OpenAI Settings openAiApiKey.setSettingsKey(Constants::OPEN_AI_API_KEY); openAiApiKey.setLabelText(Tr::tr("OpenAI API 
Key:")); openAiApiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay); @@ -77,6 +78,15 @@ ProviderSettings::ProviderSettings() openAiApiKey.setDefaultValue(""); openAiApiKey.setAutoApply(true); + // MistralAI Settings + mistralAiApiKey.setSettingsKey(Constants::MISTRAL_AI_API_KEY); + mistralAiApiKey.setLabelText(Tr::tr("Mistral AI API Key:")); + mistralAiApiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + mistralAiApiKey.setPlaceHolderText(Tr::tr("Enter your API key here")); + mistralAiApiKey.setHistoryCompleter(Constants::MISTRAL_AI_API_KEY_HISTORY); + mistralAiApiKey.setDefaultValue(""); + mistralAiApiKey.setAutoApply(true); + resetToDefaults.m_buttonText = Tr::tr("Reset Page to Defaults"); readSettings(); @@ -96,6 +106,8 @@ ProviderSettings::ProviderSettings() Group{title(Tr::tr("OpenAI Compatible Settings")), Column{openAiCompatApiKey}}, Space{8}, Group{title(Tr::tr("Claude Settings")), Column{claudeApiKey}}, + Space{8}, + Group{title(Tr::tr("Mistral AI Settings")), Column{mistralAiApiKey}}, Stretch{1}}; }); } diff --git a/settings/ProviderSettings.hpp b/settings/ProviderSettings.hpp index e8a0dfd..26cb9b3 100644 --- a/settings/ProviderSettings.hpp +++ b/settings/ProviderSettings.hpp @@ -37,6 +37,7 @@ public: Utils::StringAspect openAiCompatApiKey{this}; Utils::StringAspect claudeApiKey{this}; Utils::StringAspect openAiApiKey{this}; + Utils::StringAspect mistralAiApiKey{this}; private: void setupConnections(); diff --git a/settings/SettingsConstants.hpp b/settings/SettingsConstants.hpp index d93fd70..5166c85 100644 --- a/settings/SettingsConstants.hpp +++ b/settings/SettingsConstants.hpp @@ -100,6 +100,8 @@ const char CLAUDE_API_KEY[] = "QodeAssist.claudeApiKey"; const char CLAUDE_API_KEY_HISTORY[] = "QodeAssist.claudeApiKeyHistory"; const char OPEN_AI_API_KEY[] = "QodeAssist.openAiApiKey"; const char OPEN_AI_API_KEY_HISTORY[] = "QodeAssist.openAiApiKeyHistory"; +const char MISTRAL_AI_API_KEY[] = "QodeAssist.mistralAiApiKey"; +const char MISTRAL_AI_API_KEY_HISTORY[] = "QodeAssist.mistralAiApiKeyHistory"; // context settings const char CC_READ_FULL_FILE[] = "QodeAssist.ccReadFullFile"; diff --git a/templates/Alpaca.hpp b/templates/Alpaca.hpp index abc811b..472c3d8 100644 --- a/templates/Alpaca.hpp +++ b/templates/Alpaca.hpp @@ -29,33 +29,32 @@ class Alpaca : public LLMCore::PromptTemplate public: QString name() const override { return "Alpaca"; } LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; } - QString promptTemplate() const override { return {}; } QStringList stopWords() const override { return QStringList() << "### Instruction:" << "### Response:"; } void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override { - QJsonArray messages = request["messages"].toArray(); + QJsonArray messages; - for (int i = 0; i < messages.size(); ++i) { - QJsonObject message = messages[i].toObject(); - QString role = message["role"].toString(); - QString content = message["content"].toString(); + QString fullContent; - QString formattedContent; - if (role == "system") { - formattedContent = content + "\n\n"; - } else if (role == "user") { - formattedContent = "### Instruction:\n" + content + "\n\n"; - } else if (role == "assistant") { - formattedContent = "### Response:\n" + content + "\n\n"; - } - - message["content"] = formattedContent; - messages[i] = message; + if (context.systemPrompt) { + fullContent += context.systemPrompt.value() + "\n\n"; } + if (context.history) { + for (const auto &msg : 
+                if (msg.role == "user") {
+                    fullContent += QString("### Instruction:\n%1\n\n").arg(msg.content);
+                } else if (msg.role == "assistant") {
+                    fullContent += QString("### Response:\n%1\n\n").arg(msg.content);
+                }
+            }
+        }
+
+        messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});
+
         request["messages"] = messages;
     }
     QString description() const override
diff --git a/templates/ChatML.hpp b/templates/ChatML.hpp
index bb125cf..4a1bd02 100644
--- a/templates/ChatML.hpp
+++ b/templates/ChatML.hpp
@@ -30,23 +30,28 @@ class ChatML : public LLMCore::PromptTemplate
 public:
     QString name() const override { return "ChatML"; }
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-    QString promptTemplate() const override { return {}; }
     QStringList stopWords() const override
     {
         return QStringList() << "<|im_start|>" << "<|im_end|>";
     }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QJsonArray messages = request["messages"].toArray();
+        QJsonArray messages;
 
-        for (int i = 0; i < messages.size(); ++i) {
-            QJsonObject message = messages[i].toObject();
-            QString role = message["role"].toString();
-            QString content = message["content"].toString();
+        if (context.systemPrompt) {
+            messages.append(QJsonObject{
+                {"role", "system"},
+                {"content",
+                 QString("<|im_start|>system\n%2\n<|im_end|>").arg(context.systemPrompt.value())}});
+        }
 
-            message["content"] = QString("<|im_start|>%1\n%2\n<|im_end|>").arg(role, content);
-
-            messages[i] = message;
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                messages.append(QJsonObject{
+                    {"role", msg.role},
+                    {"content",
+                     QString("<|im_start|>%1\n%2\n<|im_end|>").arg(msg.role, msg.content)}});
+            }
         }
 
         request["messages"] = messages;
diff --git a/templates/Claude.hpp b/templates/Claude.hpp
index 2e66026..6043aa5 100644
--- a/templates/Claude.hpp
+++ b/templates/Claude.hpp
@@ -30,9 +30,25 @@ class Claude : public LLMCore::PromptTemplate
 public:
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
     QString name() const override { return "Claude"; }
-    QString promptTemplate() const override { return {}; }
     QStringList stopWords() const override { return QStringList(); }
-    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
+    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
+    {
+        QJsonArray messages;
+
+        if (context.systemPrompt) {
+            request["system"] = context.systemPrompt.value();
+        }
+
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                if (msg.role != "system") {
+                    messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
+                }
+            }
+        }
+
+        request["messages"] = messages;
+    }
     QString description() const override { return "Claude"; }
 };
diff --git a/templates/CodeLlamaFim.hpp b/templates/CodeLlamaFim.hpp
index 36f2e2c..63817b7 100644
--- a/templates/CodeLlamaFim.hpp
+++ b/templates/CodeLlamaFim.hpp
@@ -26,17 +26,17 @@ namespace QodeAssist::Templates {
 class CodeLlamaFim : public LLMCore::PromptTemplate
 {
 public:
-    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
+    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
     QString name() const override { return "CodeLlama FIM"; }
-    QString promptTemplate() const override { return "<PRE> %1 <SUF>%2 <MID>"; }
     QStringList stopWords() const override
     {
         return QStringList() << "<PRE>" << "<SUF>" << "<MID>";
     }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
-        request["prompt"] = formattedPrompt;
+        request["prompt"] = QString("
 %1 %2 ")
+                                .arg(context.prefix.value_or(""), context.suffix.value_or(""));
+        request["system"] = context.systemPrompt.value_or("");
     }
     QString description() const override
     {
diff --git a/templates/CodeLlamaQMLFim.hpp b/templates/CodeLlamaQMLFim.hpp
index 1be20eb..df827be 100644
--- a/templates/CodeLlamaQMLFim.hpp
+++ b/templates/CodeLlamaQMLFim.hpp
@@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
 class CodeLlamaQMLFim : public LLMCore::PromptTemplate
 {
 public:
-    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
+    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
     QString name() const override { return "CodeLlama QML FIM"; }
-    QString promptTemplate() const override { return "<SUF>%1<PRE>%2<MID>"; }
     QStringList stopWords() const override
     {
         return QStringList() << "<SUF>" << "<PRE>" << "</SUF>" << "</PRE>" << "< EOT >" << "\\end"
@@ -36,8 +35,9 @@ public:
     }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QString formattedPrompt = promptTemplate().arg(context.suffix, context.prefix);
-        request["prompt"] = formattedPrompt;
+        request["prompt"] = QString("<SUF>%1<PRE>%2<MID>")
+                                .arg(context.suffix.value_or(""), context.prefix.value_or(""));
+        request["system"] = context.systemPrompt.value_or("");
     }
     QString description() const override
     {
diff --git a/templates/Llama2.hpp b/templates/Llama2.hpp
index 6888f7b..ffcdf3c 100644
--- a/templates/Llama2.hpp
+++ b/templates/Llama2.hpp
@@ -29,30 +29,30 @@ class Llama2 : public LLMCore::PromptTemplate
 public:
     QString name() const override { return "Llama 2"; }
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-    QString promptTemplate() const override { return {}; }
     QStringList stopWords() const override { return QStringList() << "[INST]"; }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QJsonArray messages = request["messages"].toArray();
+        QJsonArray messages;
 
-        for (int i = 0; i < messages.size(); ++i) {
-            QJsonObject message = messages[i].toObject();
-            QString role = message["role"].toString();
-            QString content = message["content"].toString();
+        QString fullContent;
 
-            QString formattedContent;
-            if (role == "system") {
-                formattedContent = QString("[INST]<>\n%1\n<>[/INST]\n").arg(content);
-            } else if (role == "user") {
-                formattedContent = QString("[INST]%1[/INST]\n").arg(content);
-            } else if (role == "assistant") {
-                formattedContent = content + "\n";
-            }
-
-            message["content"] = formattedContent;
-            messages[i] = message;
+        if (context.systemPrompt) {
+            fullContent
+                += QString("[INST]<>\n%1\n<>[/INST]\n").arg(context.systemPrompt.value());
         }
 
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                if (msg.role == "user") {
+                    fullContent += QString("[INST]%1[/INST]\n").arg(msg.content);
+                } else if (msg.role == "assistant") {
+                    fullContent += msg.content + "\n";
+                }
+            }
+        }
+
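+        // Llama 2 has no per-message chat roles here: the system prompt and the
+        // history are flattened into a single user message using [INST] markup.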
+        messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});
+
         request["messages"] = messages;
     }
     QString description() const override
diff --git a/templates/Llama3.hpp b/templates/Llama3.hpp
index 7b3bf2b..26c5f68 100644
--- a/templates/Llama3.hpp
+++ b/templates/Llama3.hpp
@@ -30,24 +30,30 @@ class Llama3 : public LLMCore::PromptTemplate
 public:
     QString name() const override { return "Llama 3"; }
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-    QString promptTemplate() const override { return ""; }
     QStringList stopWords() const override
     {
         return QStringList() << "<|start_header_id|>" << "<|end_header_id|>" << "<|eot_id|>";
     }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QJsonArray messages = request["messages"].toArray();
+        QJsonArray messages;
 
-        for (int i = 0; i < messages.size(); ++i) {
-            QJsonObject message = messages[i].toObject();
-            QString role = message["role"].toString();
-            QString content = message["content"].toString();
+        if (context.systemPrompt) {
+            messages.append(QJsonObject{
+                {"role", "system"},
+                {"content",
+                 QString("<|start_header_id|>system<|end_header_id|>%2<|eot_id|>")
+                     .arg(context.systemPrompt.value())}});
+        }
 
-            message["content"]
-                = QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>").arg(role, content);
-
-            messages[i] = message;
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                messages.append(QJsonObject{
+                    {"role", msg.role},
+                    {"content",
+                     QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>")
+                         .arg(msg.role, msg.content)}});
+            }
         }
 
         request["messages"] = messages;
diff --git a/templates/MistralAI.hpp b/templates/MistralAI.hpp
new file mode 100644
index 0000000..ddada9b
--- /dev/null
+++ b/templates/MistralAI.hpp
@@ -0,0 +1,69 @@
+/* 
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <QJsonArray>
+
+#include "llmcore/PromptTemplate.hpp"
+
+namespace QodeAssist::Templates {
+
+class MistralAIFim : public LLMCore::PromptTemplate
+{
+public:
+    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
+    QString name() const override { return "Mistral AI FIM"; }
+    QStringList stopWords() const override { return QStringList(); }
+    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
+    {
+        request["prompt"] = context.prefix.value_or("");
+        request["suffix"] = context.suffix.value_or("");
+    }
+    QString description() const override { return "template will take from ollama modelfile"; }
+};
+
+class MistralAIChat : public LLMCore::PromptTemplate
+{
+public:
+    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
+    QString name() const override { return "Mistral AI Chat"; }
+    QStringList stopWords() const override { return QStringList(); }
+
+    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
+    {
+        QJsonArray messages;
+
+        if (context.systemPrompt) {
+            messages.append(
+                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
+        }
+
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
+            }
+        }
+
+        request["messages"] = messages;
+    }
+    QString description() const override { return "template will take from ollama modelfile"; }
+};
+
+} // namespace QodeAssist::Templates
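(A quick sketch of the request body MistralAIFim produces — the identifiers are taken from the new file above; the sample values are illustrative:)

    #include <QJsonObject>
    #include "templates/MistralAI.hpp"

    // inside some function:
    LLMCore::ContextData ctx;
    ctx.prefix = QString("int add(int a, int b) {\n    return ");
    ctx.suffix = QString(";\n}");

    QJsonObject request;
    QodeAssist::Templates::MistralAIFim{}.prepareRequest(request, ctx);
    // request == {"prompt": "int add(int a, int b) {\n    return ",
    //             "suffix": ";\n}"}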
diff --git a/templates/Ollama.hpp b/templates/Ollama.hpp
index 2e23664..9f19e1f 100644
--- a/templates/Ollama.hpp
+++ b/templates/Ollama.hpp
@@ -25,31 +25,44 @@
 
 namespace QodeAssist::Templates {
 
-class OllamaAutoFim : public LLMCore::PromptTemplate
+class OllamaFim : public LLMCore::PromptTemplate
 {
 public:
-    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
-    QString name() const override { return "Ollama Auto FIM"; }
-    QString promptTemplate() const override { return {}; }
+    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
+    QString name() const override { return "Ollama FIM"; }
     QStringList stopWords() const override { return QStringList(); }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        request["prompt"] = context.prefix;
-        request["suffix"] = context.suffix;
+        request["prompt"] = context.prefix.value_or("");
+        request["suffix"] = context.suffix.value_or("");
+        request["system"] = context.systemPrompt.value_or("");
     }
     QString description() const override { return "template will take from ollama modelfile"; }
 };
 
-class OllamaAutoChat : public LLMCore::PromptTemplate
+class OllamaChat : public LLMCore::PromptTemplate
 {
 public:
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-    QString name() const override { return "Ollama Auto Chat"; }
-    QString promptTemplate() const override { return {}; }
+    QString name() const override { return "Ollama Chat"; }
     QStringList stopWords() const override { return QStringList(); }
 
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
+        QJsonArray messages;
+
+        if (context.systemPrompt) {
+            messages.append(
+                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
+        }
+
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
+            }
+        }
+
+        request["messages"] = messages;
     }
     QString description() const override { return "template will take from ollama modelfile"; }
 };
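(Similarly for the renamed OllamaFim — note the new "system" field; unset optionals degrade to empty strings via value_or. A sketch with illustrative values:)

    #include <QJsonObject>
    #include "templates/Ollama.hpp"

    // inside some function:
    LLMCore::ContextData ctx;
    ctx.prefix = QString("def fib(n):\n    ");
    ctx.systemPrompt = QString("Complete the code.");
    // ctx.suffix left unset

    QJsonObject request;
    QodeAssist::Templates::OllamaFim{}.prepareRequest(request, ctx);
    // request == {"prompt": "def fib(n):\n    ", "suffix": "", "system": "Complete the code."}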
diff --git a/templates/OpenAI.hpp b/templates/OpenAI.hpp
index e5bc467..491f6e0 100644
--- a/templates/OpenAI.hpp
+++ b/templates/OpenAI.hpp
@@ -30,9 +30,24 @@ class OpenAI : public LLMCore::PromptTemplate
 public:
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
     QString name() const override { return "OpenAI"; }
-    QString promptTemplate() const override { return {}; }
     QStringList stopWords() const override { return QStringList(); }
-    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
+    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
+    {
+        QJsonArray messages;
+
+        if (context.systemPrompt) {
+            messages.append(
+                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
+        }
+
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
+            }
+        }
+
+        request["messages"] = messages;
+    }
     QString description() const override { return "OpenAI"; }
 };
 
diff --git a/templates/BasicChat.hpp b/templates/OpenAICompatible.hpp
similarity index 67%
rename from templates/BasicChat.hpp
rename to templates/OpenAICompatible.hpp
index 7e552b8..d086faa 100644
--- a/templates/BasicChat.hpp
+++ b/templates/OpenAICompatible.hpp
@@ -25,15 +25,29 @@
 
 namespace QodeAssist::Templates {
 
-class BasicChat : public LLMCore::PromptTemplate
+class OpenAICompatible : public LLMCore::PromptTemplate
 {
 public:
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-    QString name() const override { return "Basic Chat"; }
-    QString promptTemplate() const override { return {}; }
+    QString name() const override { return "OpenAI Compatible"; }
     QStringList stopWords() const override { return QStringList(); }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
-    {}
+    {
+        QJsonArray messages;
+
+        if (context.systemPrompt) {
+            messages.append(
+                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
+        }
+
+        if (context.history) {
+            for (const auto &msg : context.history.value()) {
+                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
+            }
+        }
+
+        request["messages"] = messages;
+    }
     QString description() const override { return "chat without tokens"; }
 };
 
diff --git a/templates/Qwen.hpp b/templates/Qwen.hpp
index 1231a74..caeae74 100644
--- a/templates/Qwen.hpp
+++ b/templates/Qwen.hpp
@@ -28,16 +28,13 @@ class QwenFim : public LLMCore::PromptTemplate
 {
 public:
     QString name() const override { return "Qwen FIM"; }
-    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
-    QString promptTemplate() const override
-    {
-        return "<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>";
-    }
+    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
     QStringList stopWords() const override { return QStringList() << "<|endoftext|>" << "<|EOT|>"; }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
-        request["prompt"] = formattedPrompt;
+        request["prompt"] = QString("<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>")
+                                .arg(context.prefix.value_or(""), context.suffix.value_or(""));
+        request["system"] = context.systemPrompt.value_or("");
     }
     QString description() const override
     {
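(The Qwen FIM prompt is now built inline rather than via promptTemplate(); a sketch of the resulting request, with illustrative values:)

    #include <QJsonObject>
    #include "templates/Qwen.hpp"

    // inside some function:
    LLMCore::ContextData ctx;
    ctx.prefix = QString("a = ");
    ctx.suffix = QString("\nprint(a)");

    QJsonObject request;
    QodeAssist::Templates::QwenFim{}.prepareRequest(request, ctx);
    // request["prompt"] == "<|fim_prefix|>a = <|fim_suffix|>\nprint(a)<|fim_middle|>"
    // request["system"] == ""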
diff --git a/templates/StarCoder2Fim.hpp b/templates/StarCoder2Fim.hpp
index 00f68da..c3f535a 100644
--- a/templates/StarCoder2Fim.hpp
+++ b/templates/StarCoder2Fim.hpp
@@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
 class StarCoder2Fim : public LLMCore::PromptTemplate
 {
 public:
-    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
+    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
     QString name() const override { return "StarCoder2 FIM"; }
-    QString promptTemplate() const override { return "<fim_prefix>%1<fim_suffix>%2<fim_middle>"; }
     QStringList stopWords() const override
     {
         return QStringList() << "<|endoftext|>" << "" << "" << ""
@@ -36,8 +35,9 @@ public:
     }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
-        request["prompt"] = formattedPrompt;
+        request["prompt"] = QString("%1%2")
+                                .arg(context.prefix.value_or(""), context.suffix.value_or(""));
+        request["system"] = context.systemPrompt.value_or("");
     }
     QString description() const override
     {
diff --git a/templates/Templates.hpp b/templates/Templates.hpp
index 872b58d..38c7255 100644
--- a/templates/Templates.hpp
+++ b/templates/Templates.hpp
@@ -21,17 +21,18 @@
 
 #include "llmcore/PromptTemplateManager.hpp"
 #include "templates/Alpaca.hpp"
-#include "templates/BasicChat.hpp"
 #include "templates/ChatML.hpp"
 #include "templates/Claude.hpp"
 #include "templates/CodeLlamaFim.hpp"
 #include "templates/CodeLlamaQMLFim.hpp"
-#include "templates/CustomFimTemplate.hpp"
-#include "templates/DeepSeekCoderFim.hpp"
-#include "templates/Llama2.hpp"
-#include "templates/Llama3.hpp"
+#include "templates/MistralAI.hpp"
 #include "templates/Ollama.hpp"
 #include "templates/OpenAI.hpp"
+#include "templates/OpenAICompatible.hpp"
+// #include "templates/CustomFimTemplate.hpp"
+// #include "templates/DeepSeekCoderFim.hpp"
+#include "templates/Llama2.hpp"
+#include "templates/Llama3.hpp"
 #include "templates/Qwen.hpp"
 #include "templates/StarCoder2Fim.hpp"
 
@@ -41,20 +42,22 @@ inline void registerTemplates()
 {
     auto &templateManager = LLMCore::PromptTemplateManager::instance();
     templateManager.registerTemplate<CodeLlamaFim>();
-    templateManager.registerTemplate<StarCoder2Fim>();
-    templateManager.registerTemplate<DeepSeekCoderFim>();
-    templateManager.registerTemplate<CustomTemplate>();
-    templateManager.registerTemplate<QwenFim>();
-    templateManager.registerTemplate<OllamaAutoFim>();
-    templateManager.registerTemplate<OllamaAutoChat>();
-    templateManager.registerTemplate<BasicChat>();
-    templateManager.registerTemplate<Llama3>();
-    templateManager.registerTemplate<ChatML>();
-    templateManager.registerTemplate<Alpaca>();
-    templateManager.registerTemplate<Llama2>();
+    templateManager.registerTemplate<OllamaFim>();
+    templateManager.registerTemplate<OllamaChat>();
     templateManager.registerTemplate<Claude>();
     templateManager.registerTemplate<OpenAI>();
+    templateManager.registerTemplate<MistralAIFim>();
+    templateManager.registerTemplate<MistralAIChat>();
     templateManager.registerTemplate<CodeLlamaQMLFim>();
+    templateManager.registerTemplate<StarCoder2Fim>();
+    templateManager.registerTemplate<QwenFim>();
+    templateManager.registerTemplate<OpenAICompatible>();
+    templateManager.registerTemplate<Llama3>();
+    // templateManager.registerTemplate<CustomTemplate>();
+    // templateManager.registerTemplate<DeepSeekCoderFim>();
+    templateManager.registerTemplate<ChatML>();
+    templateManager.registerTemplate<Alpaca>();
+    templateManager.registerTemplate<Llama2>();
 }
 
 } // namespace QodeAssist::Templates
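(For reference, a hypothetical call site — only registerTemplates() itself is defined above; the entry-point name is an assumption:)

    #include "templates/Templates.hpp"

    void initializePlugin()  // hypothetical plugin entry point
    {
        // Registers every active prompt template with the singleton manager.
        QodeAssist::Templates::registerTemplates();
    }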