refactor: Rework providers and templates logic

This commit is contained in:
Petr Mironychev 2025-02-22 19:39:28 +01:00 committed by GitHub
parent bd25736a55
commit d96f44d42c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 701 additions and 524 deletions

View File

@ -43,26 +43,28 @@ add_qtc_plugin(QodeAssist
LLMClientInterface.hpp LLMClientInterface.cpp
templates/Templates.hpp
templates/CodeLlamaFim.hpp
templates/StarCoder2Fim.hpp
templates/DeepSeekCoderFim.hpp
templates/CustomFimTemplate.hpp
templates/Qwen.hpp
templates/Ollama.hpp
templates/BasicChat.hpp
templates/Claude.hpp
templates/OpenAI.hpp
templates/MistralAI.hpp
templates/StarCoder2Fim.hpp
# templates/DeepSeekCoderFim.hpp
# templates/CustomFimTemplate.hpp
templates/Qwen.hpp
templates/OpenAICompatible.hpp
templates/Llama3.hpp
templates/ChatML.hpp
templates/Alpaca.hpp
templates/Llama2.hpp
templates/Claude.hpp
templates/OpenAI.hpp
templates/CodeLlamaQMLFim.hpp
providers/Providers.hpp
providers/OllamaProvider.hpp providers/OllamaProvider.cpp
providers/ClaudeProvider.hpp providers/ClaudeProvider.cpp
providers/OpenAIProvider.hpp providers/OpenAIProvider.cpp
providers/MistralAIProvider.hpp providers/MistralAIProvider.cpp
providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp
providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp
providers/OpenRouterAIProvider.hpp providers/OpenRouterAIProvider.cpp
providers/ClaudeProvider.hpp providers/ClaudeProvider.cpp
providers/OpenAIProvider.hpp providers/OpenAIProvider.cpp
QodeAssist.qrc
LSPCompletion.hpp
LLMSuggestion.hpp LLMSuggestion.cpp

View File

@ -93,51 +93,35 @@ void ClientInterface::sendMessage(
}
LLMCore::ContextData context;
context.prefix = message;
context.suffix = "";
QString systemPrompt;
if (chatAssistantSettings.useSystemPrompt())
systemPrompt = chatAssistantSettings.systemPrompt();
if (!linkedFiles.isEmpty()) {
systemPrompt = getSystemPromptWithLinkedFiles(systemPrompt, linkedFiles);
if (chatAssistantSettings.useSystemPrompt()) {
QString systemPrompt = chatAssistantSettings.systemPrompt();
if (!linkedFiles.isEmpty()) {
systemPrompt = getSystemPromptWithLinkedFiles(systemPrompt, linkedFiles);
}
context.systemPrompt = systemPrompt;
}
QJsonObject providerRequest;
providerRequest["model"] = Settings::generalSettings().caModel();
providerRequest["stream"] = chatAssistantSettings.stream();
providerRequest["messages"] = m_chatModel->prepareMessagesForRequest(systemPrompt);
if (promptTemplate)
promptTemplate->prepareRequest(providerRequest, context);
else
qWarning("No prompt template found");
if (provider)
provider->prepareRequest(providerRequest, LLMCore::RequestType::Chat);
else
qWarning("No provider found");
QVector<LLMCore::Message> messages;
for (const auto &msg : m_chatModel->getChatHistory()) {
messages.append({msg.role == ChatModel::ChatRole::User ? "user" : "assistant", msg.content});
}
context.history = messages;
LLMCore::LLMConfig config;
config.requestType = LLMCore::RequestType::Chat;
config.provider = provider;
config.promptTemplate = promptTemplate;
config.url = QString("%1%2").arg(Settings::generalSettings().caUrl(), provider->chatEndpoint());
config.providerRequest = providerRequest;
config.multiLineCompletion = false;
config.apiKey = provider->apiKey();
config.providerRequest
= {{"model", Settings::generalSettings().caModel()},
{"stream", chatAssistantSettings.stream()}};
QJsonObject request;
request["id"] = QUuid::createUuid().toString();
auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {
LOG_MESSAGE("Validate errors for chat request:");
LOG_MESSAGES(errors);
return;
}
config.provider
->prepareRequest(config.providerRequest, promptTemplate, context, LLMCore::RequestType::Chat);
QJsonObject request{{"id", QUuid::createUuid().toString()}};
m_requestHandler->sendLLMRequest(config, request);
}

View File

@ -28,7 +28,6 @@
#include "CodeHandler.hpp"
#include "context/DocumentContextReader.hpp"
#include "llmcore/MessageBuilder.hpp"
#include "llmcore/PromptTemplateManager.hpp"
#include "llmcore/ProvidersManager.hpp"
#include "logger/Logger.hpp"
@ -159,7 +158,7 @@ bool QodeAssist::LLMClientInterface::isSpecifyCompletion(const QJsonObject &requ
void LLMClientInterface::handleCompletion(const QJsonObject &request)
{
const auto updatedContext = prepareContext(request);
auto updatedContext = prepareContext(request);
auto &completeSettings = Settings::codeCompletionSettings();
auto &generalSettings = Settings::generalSettings();
@ -196,7 +195,7 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
config.promptTemplate = promptTemplate;
config.url = QUrl(QString("%1%2").arg(
url,
promptTemplate->type() == LLMCore::TemplateType::Fim ? provider->completionEndpoint()
promptTemplate->type() == LLMCore::TemplateType::FIM ? provider->completionEndpoint()
: provider->chatEndpoint()));
config.apiKey = provider->apiKey();
@ -211,29 +210,30 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
QString systemPrompt;
if (completeSettings.useSystemPrompt())
systemPrompt.append(completeSettings.systemPrompt());
if (!updatedContext.fileContext.isEmpty())
systemPrompt.append(updatedContext.fileContext);
if (updatedContext.fileContext.has_value())
systemPrompt.append(updatedContext.fileContext.value());
QString userMessage;
if (completeSettings.useUserMessageTemplateForCC()
&& promptTemplate->type() == LLMCore::TemplateType::Chat) {
userMessage
= completeSettings.processMessageToFIM(updatedContext.prefix, updatedContext.suffix);
} else {
userMessage = updatedContext.prefix;
updatedContext.systemPrompt = systemPrompt;
if (promptTemplate->type() == LLMCore::TemplateType::Chat) {
QString userMessage;
if (completeSettings.useUserMessageTemplateForCC()) {
userMessage = completeSettings.processMessageToFIM(
updatedContext.prefix.value_or(""), updatedContext.suffix.value_or(""));
} else {
userMessage = updatedContext.prefix.value_or("") + updatedContext.suffix.value_or("");
}
QVector<LLMCore::Message> messages;
messages.append({"user", userMessage});
updatedContext.history = messages;
}
auto message = LLMCore::MessageBuilder()
.addSystemMessage(systemPrompt)
.addUserMessage(userMessage)
.addSuffix(updatedContext.suffix)
.addTokenizer(promptTemplate);
message.saveTo(
config.provider->prepareRequest(
config.providerRequest,
providerName == "Ollama" ? LLMCore::ProvidersApi::Ollama : LLMCore::ProvidersApi::OpenAI);
config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::CodeCompletion);
promptTemplate,
updatedContext,
LLMCore::RequestType::CodeCompletion);
auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {

View File

@ -216,7 +216,7 @@ LLMCore::ContextData DocumentContextReader::prepareContext(int lineNumber, int c
fileContext.append("\n ").append(
ChangesManager::instance().getRecentChangesContext(m_textDocument));
return {contextBefore, contextAfter, fileContext};
return {.prefix = contextBefore, .suffix = contextAfter, .fileContext = fileContext};
}
QString DocumentContextReader::getContextBefore(int lineNumber, int cursorPosition) const

View File

@ -10,7 +10,6 @@ add_library(LLMCore STATIC
OllamaMessage.hpp OllamaMessage.cpp
OpenAIMessage.hpp OpenAIMessage.cpp
ValidationUtils.hpp ValidationUtils.cpp
MessageBuilder.hpp MessageBuilder.cpp
)
target_link_libraries(LLMCore

View File

@ -20,14 +20,23 @@
#pragma once
#include <QString>
#include <QVector>
namespace QodeAssist::LLMCore {
struct Message
{
QString role;
QString content;
};
struct ContextData
{
QString prefix;
QString suffix;
QString fileContext;
std::optional<QString> systemPrompt;
std::optional<QString> prefix;
std::optional<QString> suffix;
std::optional<QString> fileContext;
std::optional<QVector<Message>> history;
};
} // namespace QodeAssist::LLMCore

View File

@ -1,93 +0,0 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#include "MessageBuilder.hpp"
QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addSystemMessage(
const QString &content)
{
m_systemMessage = content;
return *this;
}
QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addUserMessage(
const QString &content)
{
m_messages.append({MessageRole::User, content});
return *this;
}
QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addSuffix(
const QString &content)
{
m_suffix = content;
return *this;
}
QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addTokenizer(
PromptTemplate *promptTemplate)
{
m_promptTemplate = promptTemplate;
return *this;
}
QString QodeAssist::LLMCore::MessageBuilder::roleToString(MessageRole role) const
{
switch (role) {
case MessageRole::System:
return ROLE_SYSTEM;
case MessageRole::User:
return ROLE_USER;
case MessageRole::Assistant:
return ROLE_ASSISTANT;
default:
return ROLE_USER;
}
}
void QodeAssist::LLMCore::MessageBuilder::saveTo(QJsonObject &request, ProvidersApi api)
{
if (!m_promptTemplate) {
return;
}
ContextData context{
m_messages.isEmpty() ? QString() : m_messages.last().content, m_suffix, m_systemMessage};
if (api == ProvidersApi::Ollama) {
if (m_promptTemplate->type() == TemplateType::Fim) {
request["system"] = m_systemMessage;
m_promptTemplate->prepareRequest(request, context);
} else {
QJsonArray messages;
messages.append(QJsonObject{{"role", "system"}, {"content", m_systemMessage}});
messages.append(QJsonObject{{"role", "user"}, {"content", m_messages.last().content}});
request["messages"] = messages;
m_promptTemplate->prepareRequest(request, context);
}
} else if (api == ProvidersApi::OpenAI) {
QJsonArray messages;
messages.append(QJsonObject{{"role", "system"}, {"content", m_systemMessage}});
messages.append(QJsonObject{{"role", "user"}, {"content", m_messages.last().content}});
request["messages"] = messages;
m_promptTemplate->prepareRequest(request, context);
}
}

View File

@ -1,68 +0,0 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <QJsonArray>
#include <QJsonObject>
#include <QString>
#include <QVector>
#include "PromptTemplate.hpp"
namespace QodeAssist::LLMCore {
enum class MessageRole { System, User, Assistant };
enum class OllamaFormat { Messages, Completions };
enum class ProvidersApi { Ollama, OpenAI, Claude };
static const QString ROLE_SYSTEM = "system";
static const QString ROLE_USER = "user";
static const QString ROLE_ASSISTANT = "assistant";
struct Message
{
MessageRole role;
QString content;
};
class MessageBuilder
{
public:
MessageBuilder &addSystemMessage(const QString &content);
MessageBuilder &addUserMessage(const QString &content);
MessageBuilder &addSuffix(const QString &content);
MessageBuilder &addTokenizer(PromptTemplate *promptTemplate);
QString roleToString(MessageRole role) const;
void saveTo(QJsonObject &request, ProvidersApi api);
private:
QString m_systemMessage;
QString m_suffix;
QVector<Message> m_messages;
PromptTemplate *m_promptTemplate;
};
} // namespace QodeAssist::LLMCore

View File

@ -27,7 +27,7 @@
namespace QodeAssist::LLMCore {
enum class TemplateType { Chat, Fim };
enum class TemplateType { Chat, FIM };
class PromptTemplate
{
@ -35,7 +35,7 @@ public:
virtual ~PromptTemplate() = default;
virtual TemplateType type() const = 0;
virtual QString name() const = 0;
virtual QString promptTemplate() const = 0;
// virtual QString promptTemplate() const = 0;
virtual QStringList stopWords() const = 0;
virtual void prepareRequest(QJsonObject &request, const ContextData &context) const = 0;
virtual QString description() const = 0;

View File

@ -23,6 +23,7 @@
#include <QNetworkRequest>
#include <QString>
#include "ContextData.hpp"
#include "PromptTemplate.hpp"
#include "RequestType.hpp"
@ -42,7 +43,12 @@ public:
virtual QString chatEndpoint() const = 0;
virtual bool supportsModelListing() const = 0;
virtual void prepareRequest(QJsonObject &request, RequestType type) = 0;
virtual void prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
= 0;
virtual bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) = 0;
virtual QList<QString> getInstalledModels(const QString &url) = 0;
virtual QList<QString> validateRequest(const QJsonObject &request, TemplateType type) = 0;

View File

@ -58,7 +58,10 @@ void RequestHandler::sendLLMRequest(const LLMConfig &config, const QJsonObject &
reply->deleteLater();
m_activeRequests.remove(requestId);
if (reply->error() != QNetworkReply::NoError) {
LOG_MESSAGE(QString("Error in QodeAssist request: %1").arg(reply->errorString()));
LOG_MESSAGE(QString("Error details: %1\nStatus code: %2\nResponse: %3")
.arg(reply->errorString())
.arg(reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt())
.arg(QString(reply->readAll())));
emit requestFinished(requestId, false, reply->errorString());
} else {
LOG_MESSAGE("Request finished successfully");

View File

@ -21,5 +21,5 @@
namespace QodeAssist::LLMCore {
enum RequestType { CodeCompletion, Chat };
enum RequestType { CodeCompletion, Chat, Embedding };
}

View File

@ -30,13 +30,10 @@
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
#include "settings/GeneralSettings.hpp"
#include "settings/ProviderSettings.hpp"
namespace QodeAssist::Providers {
ClaudeProvider::ClaudeProvider() {}
QString ClaudeProvider::name() const
{
return "Claude";
@ -62,31 +59,17 @@ bool ClaudeProvider::supportsModelListing() const
return true;
}
void ClaudeProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void ClaudeProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
QJsonArray messages;
if (req.contains("messages")) {
QJsonArray origMessages = req["messages"].toArray();
for (const auto &msg : origMessages) {
QJsonObject message = msg.toObject();
if (message["role"].toString() == "system") {
req["system"] = message["content"];
} else {
messages.append(message);
}
}
} else {
if (req.contains("system")) {
req["system"] = req["system"].toString();
}
if (req.contains("prompt")) {
messages.append(
QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
}
}
return messages;
};
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
prompt->prepareRequest(request, context);
auto applyModelParams = [&request](const auto &settings) {
request["max_tokens"] = settings.maxTokens();
@ -98,11 +81,6 @@ void ClaudeProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType t
request["stream"] = true;
};
QJsonArray messages = prepareMessages(request);
if (!messages.isEmpty()) {
request["messages"] = std::move(messages);
}
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else {

View File

@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class ClaudeProvider : public LLMCore::Provider
{
public:
ClaudeProvider();
QString name() const override;
QString url() const override;
QString completionEndpoint() const override;
QString chatEndpoint() const override;
bool supportsModelListing() const override;
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
void prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;

View File

@ -33,8 +33,6 @@
namespace QodeAssist::Providers {
LMStudioProvider::LMStudioProvider() {}
QString LMStudioProvider::name() const
{
return "LM Studio";
@ -60,47 +58,6 @@ bool LMStudioProvider::supportsModelListing() const
return true;
}
void LMStudioProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
{
auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
QJsonArray messages;
if (req.contains("system")) {
messages.append(
QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
}
if (req.contains("prompt")) {
messages.append(
QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
}
return messages;
};
auto applyModelParams = [&request](const auto &settings) {
request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
if (settings.useTopP())
request["top_p"] = settings.topP();
if (settings.useTopK())
request["top_k"] = settings.topK();
if (settings.useFrequencyPenalty())
request["frequency_penalty"] = settings.frequencyPenalty();
if (settings.usePresencePenalty())
request["presence_penalty"] = settings.presencePenalty();
};
QJsonArray messages = prepareMessages(request);
if (!messages.isEmpty()) {
request["messages"] = std::move(messages);
}
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else {
applyModelParams(Settings::chatAssistantSettings());
}
}
bool LMStudioProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
QByteArray data = reply->readAll();
@ -211,4 +168,37 @@ void LMStudioProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) co
networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
}
void QodeAssist::Providers::LMStudioProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
prompt->prepareRequest(request, context);
auto applyModelParams = [&request](const auto &settings) {
request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
if (settings.useTopP())
request["top_p"] = settings.topP();
if (settings.useTopK())
request["top_k"] = settings.topK();
if (settings.useFrequencyPenalty())
request["frequency_penalty"] = settings.frequencyPenalty();
if (settings.usePresencePenalty())
request["presence_penalty"] = settings.presencePenalty();
};
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else {
applyModelParams(Settings::chatAssistantSettings());
}
}
} // namespace QodeAssist::Providers

View File

@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class LMStudioProvider : public LLMCore::Provider
{
public:
LMStudioProvider();
QString name() const override;
QString url() const override;
QString completionEndpoint() const override;
QString chatEndpoint() const override;
bool supportsModelListing() const override;
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
void prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;

View File

@ -0,0 +1,215 @@
#include "MistralAIProvider.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
#include "settings/ProviderSettings.hpp"
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QNetworkReply>
#include <QtCore/qeventloop.h>
#include "llmcore/OpenAIMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
namespace QodeAssist::Providers {
QString MistralAIProvider::name() const
{
return "Mistral AI";
}
QString MistralAIProvider::url() const
{
return "https://api.mistral.ai";
}
QString MistralAIProvider::completionEndpoint() const
{
return "/v1/fim/completions";
}
QString MistralAIProvider::chatEndpoint() const
{
return "/v1/chat/completions";
}
bool MistralAIProvider::supportsModelListing() const
{
return true;
}
bool MistralAIProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
QByteArray data = reply->readAll();
if (data.isEmpty()) {
return false;
}
bool isDone = false;
QByteArrayList lines = data.split('\n');
for (const QByteArray &line : lines) {
if (line.trimmed().isEmpty()) {
continue;
}
if (line == "data: [DONE]") {
isDone = true;
continue;
}
QByteArray jsonData = line;
if (line.startsWith("data: ")) {
jsonData = line.mid(6);
}
QJsonParseError error;
QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error);
if (doc.isNull()) {
continue;
}
auto message = LLMCore::OpenAIMessage::fromJson(doc.object());
if (message.hasError()) {
LOG_MESSAGE("Error in OpenAI response: " + message.error);
continue;
}
QString content = message.getContent();
if (!content.isEmpty()) {
accumulatedResponse += content;
}
if (message.isDone()) {
isDone = true;
}
}
return isDone;
}
QList<QString> MistralAIProvider::getInstalledModels(const QString &url)
{
QList<QString> models;
QNetworkAccessManager manager;
QNetworkRequest request(QString("%1/v1/models").arg(url));
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
if (!apiKey().isEmpty()) {
request.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8());
}
QNetworkReply *reply = manager.get(request);
QEventLoop loop;
QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
loop.exec();
if (reply->error() == QNetworkReply::NoError) {
QByteArray responseData = reply->readAll();
QJsonDocument jsonResponse = QJsonDocument::fromJson(responseData);
QJsonObject jsonObject = jsonResponse.object();
if (jsonObject.contains("data") && jsonObject["object"].toString() == "list") {
QJsonArray modelArray = jsonObject["data"].toArray();
for (const QJsonValue &value : modelArray) {
QJsonObject modelObject = value.toObject();
if (modelObject.contains("id")) {
QString modelId = modelObject["id"].toString();
models.append(modelId);
}
}
}
} else {
LOG_MESSAGE(QString("Error fetching Mistral AI models: %1").arg(reply->errorString()));
}
reply->deleteLater();
return models;
}
QList<QString> MistralAIProvider::validateRequest(
const QJsonObject &request, LLMCore::TemplateType type)
{
const auto fimReq = QJsonObject{
{"model", {}},
{"max_tokens", {}},
{"stream", {}},
{"temperature", {}},
{"prompt", {}},
{"suffix", {}}};
const auto templateReq = QJsonObject{
{"model", {}},
{"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
{"temperature", {}},
{"max_tokens", {}},
{"top_p", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}},
{"stop", QJsonArray{}},
{"stream", {}}};
return LLMCore::ValidationUtils::validateRequestFields(
request, type == LLMCore::TemplateType::FIM ? fimReq : templateReq);
}
QString MistralAIProvider::apiKey() const
{
return Settings::providerSettings().mistralAiApiKey();
}
void MistralAIProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const
{
networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
if (!apiKey().isEmpty()) {
networkRequest.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8());
}
}
void MistralAIProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
prompt->prepareRequest(request, context);
if (type == LLMCore::RequestType::Chat) {
auto &settings = Settings::chatAssistantSettings();
request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
if (settings.useTopP())
request["top_p"] = settings.topP();
// request["random_seed"] = "";
if (settings.useFrequencyPenalty())
request["frequency_penalty"] = settings.frequencyPenalty();
if (settings.usePresencePenalty())
request["presence_penalty"] = settings.presencePenalty();
} else {
auto &settings = Settings::codeCompletionSettings();
request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
if (settings.useTopP())
request["top_p"] = settings.topP();
// request["random_seed"] = "";
}
}
} // namespace QodeAssist::Providers

View File

@ -0,0 +1,46 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include "llmcore/Provider.hpp"
namespace QodeAssist::Providers {
class MistralAIProvider : public LLMCore::Provider
{
public:
QString name() const override;
QString url() const override;
QString completionEndpoint() const override;
QString chatEndpoint() const override;
bool supportsModelListing() const override;
void prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
QString apiKey() const override;
void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
};
} // namespace QodeAssist::Providers

View File

@ -33,8 +33,6 @@
namespace QodeAssist::Providers {
OllamaProvider::OllamaProvider() {}
QString OllamaProvider::name() const
{
return "Ollama";
@ -60,8 +58,18 @@ bool OllamaProvider::supportsModelListing() const
return true;
}
void OllamaProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void OllamaProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
prompt->prepareRequest(request, context);
auto applySettings = [&request](const auto &settings) {
QJsonObject options;
options["num_predict"] = settings.maxTokens();
@ -192,7 +200,7 @@ QList<QString> OllamaProvider::validateRequest(const QJsonObject &request, LLMCo
{"presence_penalty", {}}}}};
return LLMCore::ValidationUtils::validateRequestFields(
request, type == LLMCore::TemplateType::Fim ? fimReq : messageReq);
request, type == LLMCore::TemplateType::FIM ? fimReq : messageReq);
}
QString OllamaProvider::apiKey() const
@ -203,6 +211,6 @@ QString OllamaProvider::apiKey() const
void OllamaProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const
{
networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
};
}
} // namespace QodeAssist::Providers

View File

@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class OllamaProvider : public LLMCore::Provider
{
public:
OllamaProvider();
QString name() const override;
QString url() const override;
QString completionEndpoint() const override;
QString chatEndpoint() const override;
bool supportsModelListing() const override;
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
void prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;

View File

@ -34,8 +34,6 @@
namespace QodeAssist::Providers {
OpenAICompatProvider::OpenAICompatProvider() {}
QString OpenAICompatProvider::name() const
{
return "OpenAI Compatible";
@ -61,20 +59,17 @@ bool OpenAICompatProvider::supportsModelListing() const
return false;
}
void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void OpenAICompatProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
QJsonArray messages;
if (req.contains("system")) {
messages.append(
QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
}
if (req.contains("prompt")) {
messages.append(
QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
}
return messages;
};
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
prompt->prepareRequest(request, context);
auto applyModelParams = [&request](const auto &settings) {
request["max_tokens"] = settings.maxTokens();
@ -90,11 +85,6 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::Request
request["presence_penalty"] = settings.presencePenalty();
};
QJsonArray messages = prepareMessages(request);
if (!messages.isEmpty()) {
request["messages"] = std::move(messages);
}
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else {

View File

@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class OpenAICompatProvider : public LLMCore::Provider
{
public:
OpenAICompatProvider();
QString name() const override;
QString url() const override;
QString completionEndpoint() const override;
QString chatEndpoint() const override;
bool supportsModelListing() const override;
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
void prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;

View File

@ -35,8 +35,6 @@
namespace QodeAssist::Providers {
OpenAIProvider::OpenAIProvider() {}
QString OpenAIProvider::name() const
{
return "OpenAI";
@ -62,20 +60,17 @@ bool OpenAIProvider::supportsModelListing() const
return true;
}
void OpenAIProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void OpenAIProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
QJsonArray messages;
if (req.contains("system")) {
messages.append(
QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
}
if (req.contains("prompt")) {
messages.append(
QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
}
return messages;
};
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
prompt->prepareRequest(request, context);
auto applyModelParams = [&request](const auto &settings) {
request["max_tokens"] = settings.maxTokens();
@ -91,11 +86,6 @@ void OpenAIProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType t
request["presence_penalty"] = settings.presencePenalty();
};
QJsonArray messages = prepareMessages(request);
if (!messages.isEmpty()) {
request["messages"] = std::move(messages);
}
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else {

View File

@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class OpenAIProvider : public LLMCore::Provider
{
public:
OpenAIProvider();
QString name() const override;
QString url() const override;
QString completionEndpoint() const override;
QString chatEndpoint() const override;
bool supportsModelListing() const override;
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
void prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;

View File

@ -33,8 +33,6 @@
namespace QodeAssist::Providers {
OpenRouterProvider::OpenRouterProvider() {}
QString OpenRouterProvider::name() const
{
return "OpenRouter";
@ -45,47 +43,6 @@ QString OpenRouterProvider::url() const
return "https://openrouter.ai/api";
}
void OpenRouterProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
{
auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
QJsonArray messages;
if (req.contains("system")) {
messages.append(
QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
}
if (req.contains("prompt")) {
messages.append(
QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
}
return messages;
};
auto applyModelParams = [&request](const auto &settings) {
request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
if (settings.useTopP())
request["top_p"] = settings.topP();
if (settings.useTopK())
request["top_k"] = settings.topK();
if (settings.useFrequencyPenalty())
request["frequency_penalty"] = settings.frequencyPenalty();
if (settings.usePresencePenalty())
request["presence_penalty"] = settings.presencePenalty();
};
QJsonArray messages = prepareMessages(request);
if (!messages.isEmpty()) {
request["messages"] = std::move(messages);
}
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else {
applyModelParams(Settings::chatAssistantSettings());
}
}
bool OpenRouterProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
QByteArray data = reply->readAll();

View File

@ -27,11 +27,8 @@ namespace QodeAssist::Providers {
class OpenRouterProvider : public OpenAICompatProvider
{
public:
OpenRouterProvider();
QString name() const override;
QString url() const override;
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QString apiKey() const override;
};

View File

@ -22,6 +22,7 @@
#include "llmcore/ProvidersManager.hpp"
#include "providers/ClaudeProvider.hpp"
#include "providers/LMStudioProvider.hpp"
#include "providers/MistralAIProvider.hpp"
#include "providers/OllamaProvider.hpp"
#include "providers/OpenAICompatProvider.hpp"
#include "providers/OpenAIProvider.hpp"
@ -33,11 +34,12 @@ inline void registerProviders()
{
auto &providerManager = LLMCore::ProvidersManager::instance();
providerManager.registerProvider<OllamaProvider>();
providerManager.registerProvider<LMStudioProvider>();
providerManager.registerProvider<OpenAICompatProvider>();
providerManager.registerProvider<OpenRouterProvider>();
providerManager.registerProvider<ClaudeProvider>();
providerManager.registerProvider<OpenAIProvider>();
providerManager.registerProvider<OpenAICompatProvider>();
providerManager.registerProvider<LMStudioProvider>();
providerManager.registerProvider<OpenRouterProvider>();
providerManager.registerProvider<MistralAIProvider>();
}
} // namespace QodeAssist::Providers

View File

@ -69,6 +69,7 @@ ProviderSettings::ProviderSettings()
claudeApiKey.setDefaultValue("");
claudeApiKey.setAutoApply(true);
// OpenAI Settings
openAiApiKey.setSettingsKey(Constants::OPEN_AI_API_KEY);
openAiApiKey.setLabelText(Tr::tr("OpenAI API Key:"));
openAiApiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
@ -77,6 +78,15 @@ ProviderSettings::ProviderSettings()
openAiApiKey.setDefaultValue("");
openAiApiKey.setAutoApply(true);
// MistralAI Settings
mistralAiApiKey.setSettingsKey(Constants::MISTRAL_AI_API_KEY);
mistralAiApiKey.setLabelText(Tr::tr("Mistral AI API Key:"));
mistralAiApiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
mistralAiApiKey.setPlaceHolderText(Tr::tr("Enter your API key here"));
mistralAiApiKey.setHistoryCompleter(Constants::MISTRAL_AI_API_KEY_HISTORY);
mistralAiApiKey.setDefaultValue("");
mistralAiApiKey.setAutoApply(true);
resetToDefaults.m_buttonText = Tr::tr("Reset Page to Defaults");
readSettings();
@ -96,6 +106,8 @@ ProviderSettings::ProviderSettings()
Group{title(Tr::tr("OpenAI Compatible Settings")), Column{openAiCompatApiKey}},
Space{8},
Group{title(Tr::tr("Claude Settings")), Column{claudeApiKey}},
Space{8},
Group{title(Tr::tr("Mistral AI Settings")), Column{mistralAiApiKey}},
Stretch{1}};
});
}

View File

@ -37,6 +37,7 @@ public:
Utils::StringAspect openAiCompatApiKey{this};
Utils::StringAspect claudeApiKey{this};
Utils::StringAspect openAiApiKey{this};
Utils::StringAspect mistralAiApiKey{this};
private:
void setupConnections();

View File

@ -100,6 +100,8 @@ const char CLAUDE_API_KEY[] = "QodeAssist.claudeApiKey";
const char CLAUDE_API_KEY_HISTORY[] = "QodeAssist.claudeApiKeyHistory";
const char OPEN_AI_API_KEY[] = "QodeAssist.openAiApiKey";
const char OPEN_AI_API_KEY_HISTORY[] = "QodeAssist.openAiApiKeyHistory";
const char MISTRAL_AI_API_KEY[] = "QodeAssist.mistralAiApiKey";
const char MISTRAL_AI_API_KEY_HISTORY[] = "QodeAssist.mistralAiApiKeyHistory";
// context settings
const char CC_READ_FULL_FILE[] = "QodeAssist.ccReadFullFile";

View File

@ -29,33 +29,32 @@ class Alpaca : public LLMCore::PromptTemplate
public:
QString name() const override { return "Alpaca"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override
{
return QStringList() << "### Instruction:" << "### Response:";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
QString fullContent;
QString formattedContent;
if (role == "system") {
formattedContent = content + "\n\n";
} else if (role == "user") {
formattedContent = "### Instruction:\n" + content + "\n\n";
} else if (role == "assistant") {
formattedContent = "### Response:\n" + content + "\n\n";
}
message["content"] = formattedContent;
messages[i] = message;
if (context.systemPrompt) {
fullContent += context.systemPrompt.value() + "\n\n";
}
if (context.history) {
for (const auto &msg : context.history.value()) {
if (msg.role == "user") {
fullContent += QString("### Instruction:\n%1\n\n").arg(msg.content);
} else if (msg.role == "assistant") {
fullContent += QString("### Response:\n%1\n\n").arg(msg.content);
}
}
}
messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});
request["messages"] = messages;
}
QString description() const override

View File

@ -30,23 +30,28 @@ class ChatML : public LLMCore::PromptTemplate
public:
QString name() const override { return "ChatML"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override
{
return QStringList() << "<|im_start|>" << "<|im_end|>";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
if (context.systemPrompt) {
messages.append(QJsonObject{
{"role", "system"},
{"content",
QString("<|im_start|>system\n%2\n<|im_end|>").arg(context.systemPrompt.value())}});
}
message["content"] = QString("<|im_start|>%1\n%2\n<|im_end|>").arg(role, content);
messages[i] = message;
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{
{"role", msg.role},
{"content",
QString("<|im_start|>%1\n%2\n<|im_end|>").arg(msg.role, msg.content)}});
}
}
request["messages"] = messages;

View File

@ -30,9 +30,25 @@ class Claude : public LLMCore::PromptTemplate
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Claude"; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
request["system"] = context.systemPrompt.value();
}
if (context.history) {
for (const auto &msg : context.history.value()) {
if (msg.role != "system") {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
}
request["messages"] = messages;
}
QString description() const override { return "Claude"; }
};

View File

@ -26,17 +26,17 @@ namespace QodeAssist::Templates {
class CodeLlamaFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "CodeLlama FIM"; }
QString promptTemplate() const override { return "<PRE> %1 <SUF>%2 <MID>"; }
QStringList stopWords() const override
{
return QStringList() << "<EOT>" << "<PRE>" << "<SUF" << "<MID>";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<PRE> %1 <SUF>%2 <MID>")
.arg(context.prefix.value_or(""), context.suffix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
class CodeLlamaQMLFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "CodeLlama QML FIM"; }
QString promptTemplate() const override { return "<SUF>%1<PRE>%2<MID>"; }
QStringList stopWords() const override
{
return QStringList() << "<SUF>" << "<PRE>" << "</PRE>" << "</SUF>" << "< EOT >" << "\\end"
@ -36,8 +35,9 @@ public:
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.suffix, context.prefix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<SUF>%1<PRE>%2<MID>")
.arg(context.suffix.value_or(""), context.prefix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -29,30 +29,30 @@ class Llama2 : public LLMCore::PromptTemplate
public:
QString name() const override { return "Llama 2"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList() << "[INST]"; }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
QString fullContent;
QString formattedContent;
if (role == "system") {
formattedContent = QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(content);
} else if (role == "user") {
formattedContent = QString("[INST]%1[/INST]\n").arg(content);
} else if (role == "assistant") {
formattedContent = content + "\n";
}
message["content"] = formattedContent;
messages[i] = message;
if (context.systemPrompt) {
fullContent
+= QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(context.systemPrompt.value());
}
if (context.history) {
for (const auto &msg : context.history.value()) {
if (msg.role == "user") {
fullContent += QString("[INST]%1[/INST]\n").arg(msg.content);
} else if (msg.role == "assistant") {
fullContent += msg.content + "\n";
}
}
}
messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});
request["messages"] = messages;
}
QString description() const override

View File

@ -30,24 +30,30 @@ class Llama3 : public LLMCore::PromptTemplate
public:
QString name() const override { return "Llama 3"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return ""; }
QStringList stopWords() const override
{
return QStringList() << "<|start_header_id|>" << "<|end_header_id|>" << "<|eot_id|>";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
if (context.systemPrompt) {
messages.append(QJsonObject{
{"role", "system"},
{"content",
QString("<|start_header_id|>system<|end_header_id|>%2<|eot_id|>")
.arg(context.systemPrompt.value())}});
}
message["content"]
= QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>").arg(role, content);
messages[i] = message;
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{
{"role", msg.role},
{"content",
QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>")
.arg(msg.role, msg.content)}});
}
}
request["messages"] = messages;

69
templates/MistralAI.hpp Normal file
View File

@ -0,0 +1,69 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <QJsonArray>
#include "llmcore/PromptTemplate.hpp"
namespace QodeAssist::Templates {
class MistralAIFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "Mistral AI FIM"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
request["prompt"] = context.prefix.value_or("");
request["suffix"] = context.suffix.value_or("");
}
QString description() const override { return "template will take from ollama modelfile"; }
};
class MistralAIChat : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Mistral AI Chat"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "template will take from ollama modelfile"; }
};
} // namespace QodeAssist::Templates

View File

@ -25,31 +25,44 @@
namespace QodeAssist::Templates {
class OllamaAutoFim : public LLMCore::PromptTemplate
class OllamaFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
QString name() const override { return "Ollama Auto FIM"; }
QString promptTemplate() const override { return {}; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "Ollama FIM"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
request["prompt"] = context.prefix;
request["suffix"] = context.suffix;
request["prompt"] = context.prefix.value_or("");
request["suffix"] = context.suffix.value_or("");
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override { return "template will take from ollama modelfile"; }
};
class OllamaAutoChat : public LLMCore::PromptTemplate
class OllamaChat : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Ollama Auto Chat"; }
QString promptTemplate() const override { return {}; }
QString name() const override { return "Ollama Chat"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "template will take from ollama modelfile"; }
};

View File

@ -30,9 +30,24 @@ class OpenAI : public LLMCore::PromptTemplate
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "OpenAI"; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "OpenAI"; }
};

View File

@ -25,15 +25,29 @@
namespace QodeAssist::Templates {
class BasicChat : public LLMCore::PromptTemplate
class OpenAICompatible : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Basic Chat"; }
QString promptTemplate() const override { return {}; }
QString name() const override { return "OpenAI Compatible"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{}
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "chat without tokens"; }
};

View File

@ -28,16 +28,13 @@ class QwenFim : public LLMCore::PromptTemplate
{
public:
QString name() const override { return "Qwen FIM"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
QString promptTemplate() const override
{
return "<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>";
}
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QStringList stopWords() const override { return QStringList() << "<|endoftext|>" << "<|EOT|>"; }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>")
.arg(context.prefix.value_or(""), context.suffix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
class StarCoder2Fim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "StarCoder2 FIM"; }
QString promptTemplate() const override { return "<fim_prefix>%1<fim_suffix>%2<fim_middle>"; }
QStringList stopWords() const override
{
return QStringList() << "<|endoftext|>" << "<file_sep>" << "<fim_prefix>" << "<fim_suffix>"
@ -36,8 +35,9 @@ public:
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<fim_prefix>%1<fim_suffix>%2<fim_middle>")
.arg(context.prefix.value_or(""), context.suffix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -21,17 +21,18 @@
#include "llmcore/PromptTemplateManager.hpp"
#include "templates/Alpaca.hpp"
#include "templates/BasicChat.hpp"
#include "templates/ChatML.hpp"
#include "templates/Claude.hpp"
#include "templates/CodeLlamaFim.hpp"
#include "templates/CodeLlamaQMLFim.hpp"
#include "templates/CustomFimTemplate.hpp"
#include "templates/DeepSeekCoderFim.hpp"
#include "templates/Llama2.hpp"
#include "templates/Llama3.hpp"
#include "templates/MistralAI.hpp"
#include "templates/Ollama.hpp"
#include "templates/OpenAI.hpp"
#include "templates/OpenAICompatible.hpp"
// #include "templates/CustomFimTemplate.hpp"
// #include "templates/DeepSeekCoderFim.hpp"
#include "templates/Llama2.hpp"
#include "templates/Llama3.hpp"
#include "templates/Qwen.hpp"
#include "templates/StarCoder2Fim.hpp"
@ -41,20 +42,22 @@ inline void registerTemplates()
{
auto &templateManager = LLMCore::PromptTemplateManager::instance();
templateManager.registerTemplate<CodeLlamaFim>();
templateManager.registerTemplate<StarCoder2Fim>();
templateManager.registerTemplate<DeepSeekCoderFim>();
templateManager.registerTemplate<CustomTemplate>();
templateManager.registerTemplate<QwenFim>();
templateManager.registerTemplate<OllamaAutoFim>();
templateManager.registerTemplate<OllamaAutoChat>();
templateManager.registerTemplate<BasicChat>();
templateManager.registerTemplate<Llama3>();
templateManager.registerTemplate<ChatML>();
templateManager.registerTemplate<Alpaca>();
templateManager.registerTemplate<Llama2>();
templateManager.registerTemplate<OllamaFim>();
templateManager.registerTemplate<OllamaChat>();
templateManager.registerTemplate<Claude>();
templateManager.registerTemplate<OpenAI>();
templateManager.registerTemplate<MistralAIFim>();
templateManager.registerTemplate<MistralAIChat>();
templateManager.registerTemplate<CodeLlamaQMLFim>();
templateManager.registerTemplate<ChatML>();
templateManager.registerTemplate<Llama2>();
templateManager.registerTemplate<Llama3>();
templateManager.registerTemplate<StarCoder2Fim>();
// templateManager.registerTemplate<DeepSeekCoderFim>();
// templateManager.registerTemplate<CustomTemplate>();
templateManager.registerTemplate<QwenFim>();
templateManager.registerTemplate<OpenAICompatible>();
templateManager.registerTemplate<Alpaca>();
}
} // namespace QodeAssist::Templates