mirror of https://github.com/Palm1r/QodeAssist.git
synced 2025-06-04 01:28:58 -04:00
refactor: Rework providers and templates logic
This commit is contained in: parent bd25736a55, commit d96f44d42c
@@ -43,26 +43,28 @@ add_qtc_plugin(QodeAssist
    LLMClientInterface.hpp LLMClientInterface.cpp
    templates/Templates.hpp
    templates/CodeLlamaFim.hpp
    templates/StarCoder2Fim.hpp
    templates/DeepSeekCoderFim.hpp
    templates/CustomFimTemplate.hpp
    templates/Qwen.hpp
    templates/Ollama.hpp
    templates/BasicChat.hpp
    templates/Claude.hpp
    templates/OpenAI.hpp
    templates/MistralAI.hpp
    templates/StarCoder2Fim.hpp
    # templates/DeepSeekCoderFim.hpp
    # templates/CustomFimTemplate.hpp
    templates/Qwen.hpp
    templates/OpenAICompatible.hpp
    templates/Llama3.hpp
    templates/ChatML.hpp
    templates/Alpaca.hpp
    templates/Llama2.hpp
    templates/Claude.hpp
    templates/OpenAI.hpp
    templates/CodeLlamaQMLFim.hpp
    providers/Providers.hpp
    providers/OllamaProvider.hpp providers/OllamaProvider.cpp
    providers/ClaudeProvider.hpp providers/ClaudeProvider.cpp
    providers/OpenAIProvider.hpp providers/OpenAIProvider.cpp
    providers/MistralAIProvider.hpp providers/MistralAIProvider.cpp
    providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp
    providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp
    providers/OpenRouterAIProvider.hpp providers/OpenRouterAIProvider.cpp
    providers/ClaudeProvider.hpp providers/ClaudeProvider.cpp
    providers/OpenAIProvider.hpp providers/OpenAIProvider.cpp
    QodeAssist.qrc
    LSPCompletion.hpp
    LLMSuggestion.hpp LLMSuggestion.cpp
@@ -93,51 +93,35 @@ void ClientInterface::sendMessage(
}

LLMCore::ContextData context;
context.prefix = message;
context.suffix = "";

QString systemPrompt;
if (chatAssistantSettings.useSystemPrompt())
    systemPrompt = chatAssistantSettings.systemPrompt();

if (!linkedFiles.isEmpty()) {
    systemPrompt = getSystemPromptWithLinkedFiles(systemPrompt, linkedFiles);
if (chatAssistantSettings.useSystemPrompt()) {
    QString systemPrompt = chatAssistantSettings.systemPrompt();
    if (!linkedFiles.isEmpty()) {
        systemPrompt = getSystemPromptWithLinkedFiles(systemPrompt, linkedFiles);
    }
    context.systemPrompt = systemPrompt;
}

QJsonObject providerRequest;
providerRequest["model"] = Settings::generalSettings().caModel();
providerRequest["stream"] = chatAssistantSettings.stream();
providerRequest["messages"] = m_chatModel->prepareMessagesForRequest(systemPrompt);

if (promptTemplate)
    promptTemplate->prepareRequest(providerRequest, context);
else
    qWarning("No prompt template found");

if (provider)
    provider->prepareRequest(providerRequest, LLMCore::RequestType::Chat);
else
    qWarning("No provider found");
QVector<LLMCore::Message> messages;
for (const auto &msg : m_chatModel->getChatHistory()) {
    messages.append({msg.role == ChatModel::ChatRole::User ? "user" : "assistant", msg.content});
}
context.history = messages;

LLMCore::LLMConfig config;
config.requestType = LLMCore::RequestType::Chat;
config.provider = provider;
config.promptTemplate = promptTemplate;
config.url = QString("%1%2").arg(Settings::generalSettings().caUrl(), provider->chatEndpoint());
config.providerRequest = providerRequest;
config.multiLineCompletion = false;
config.apiKey = provider->apiKey();
config.providerRequest
    = {{"model", Settings::generalSettings().caModel()},
       {"stream", chatAssistantSettings.stream()}};

QJsonObject request;
request["id"] = QUuid::createUuid().toString();

auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {
    LOG_MESSAGE("Validate errors for chat request:");
    LOG_MESSAGES(errors);
    return;
}
config.provider
    ->prepareRequest(config.providerRequest, promptTemplate, context, LLMCore::RequestType::Chat);

QJsonObject request{{"id", QUuid::createUuid().toString()}};
m_requestHandler->sendLLMRequest(config, request);
}
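The hunk above replaces ad-hoc request assembly with an LLMConfig that carries the provider, the prompt template, and the chat history. A minimal sketch of the resulting flow, using only the types visible in this diff (the system prompt, message text, and model name are illustrative):

    LLMCore::ContextData context;
    context.systemPrompt = QString("You are a helpful C++ assistant");

    QVector<LLMCore::Message> history;
    history.append({"user", "How do I reverse a QString?"});
    context.history = history;

    LLMCore::LLMConfig config;
    config.requestType = LLMCore::RequestType::Chat;
    config.provider = provider;             // resolved from ProvidersManager
    config.promptTemplate = promptTemplate; // resolved from PromptTemplateManager
    config.providerRequest = {{"model", "example-model"}, {"stream", true}};

    // Validation now runs before the provider mutates the request.
    const auto errors = config.provider->validateRequest(
        config.providerRequest, config.promptTemplate->type());
    if (errors.isEmpty()) {
        config.provider->prepareRequest(
            config.providerRequest, config.promptTemplate, context,
            LLMCore::RequestType::Chat);
    }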
@@ -28,7 +28,6 @@

#include "CodeHandler.hpp"
#include "context/DocumentContextReader.hpp"
#include "llmcore/MessageBuilder.hpp"
#include "llmcore/PromptTemplateManager.hpp"
#include "llmcore/ProvidersManager.hpp"
#include "logger/Logger.hpp"
@@ -159,7 +158,7 @@ bool QodeAssist::LLMClientInterface::isSpecifyCompletion(const QJsonObject &request)

void LLMClientInterface::handleCompletion(const QJsonObject &request)
{
const auto updatedContext = prepareContext(request);
auto updatedContext = prepareContext(request);
auto &completeSettings = Settings::codeCompletionSettings();
auto &generalSettings = Settings::generalSettings();
@@ -196,7 +195,7 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)

config.promptTemplate = promptTemplate;
config.url = QUrl(QString("%1%2").arg(
    url,
    promptTemplate->type() == LLMCore::TemplateType::Fim ? provider->completionEndpoint()
    promptTemplate->type() == LLMCore::TemplateType::FIM ? provider->completionEndpoint()
                                                         : provider->chatEndpoint()));
config.apiKey = provider->apiKey();
@@ -211,29 +210,30 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)

QString systemPrompt;
if (completeSettings.useSystemPrompt())
    systemPrompt.append(completeSettings.systemPrompt());
if (!updatedContext.fileContext.isEmpty())
    systemPrompt.append(updatedContext.fileContext);
if (updatedContext.fileContext.has_value())
    systemPrompt.append(updatedContext.fileContext.value());

QString userMessage;
if (completeSettings.useUserMessageTemplateForCC()
    && promptTemplate->type() == LLMCore::TemplateType::Chat) {
    userMessage
        = completeSettings.processMessageToFIM(updatedContext.prefix, updatedContext.suffix);
} else {
    userMessage = updatedContext.prefix;
updatedContext.systemPrompt = systemPrompt;

if (promptTemplate->type() == LLMCore::TemplateType::Chat) {
    QString userMessage;
    if (completeSettings.useUserMessageTemplateForCC()) {
        userMessage = completeSettings.processMessageToFIM(
            updatedContext.prefix.value_or(""), updatedContext.suffix.value_or(""));
    } else {
        userMessage = updatedContext.prefix.value_or("") + updatedContext.suffix.value_or("");
    }

    QVector<LLMCore::Message> messages;
    messages.append({"user", userMessage});
    updatedContext.history = messages;
}

auto message = LLMCore::MessageBuilder()
                   .addSystemMessage(systemPrompt)
                   .addUserMessage(userMessage)
                   .addSuffix(updatedContext.suffix)
                   .addTokenizer(promptTemplate);

message.saveTo(
config.provider->prepareRequest(
    config.providerRequest,
    providerName == "Ollama" ? LLMCore::ProvidersApi::Ollama : LLMCore::ProvidersApi::OpenAI);

config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::CodeCompletion);
    promptTemplate,
    updatedContext,
    LLMCore::RequestType::CodeCompletion);

auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {
@@ -216,7 +216,7 @@ LLMCore::ContextData DocumentContextReader::prepareContext(int lineNumber, int cursorPosition)

fileContext.append("\n ").append(
    ChangesManager::instance().getRecentChangesContext(m_textDocument));

return {contextBefore, contextAfter, fileContext};
return {.prefix = contextBefore, .suffix = contextAfter, .fileContext = fileContext};
}

QString DocumentContextReader::getContextBefore(int lineNumber, int cursorPosition) const
@@ -10,7 +10,6 @@ add_library(LLMCore STATIC
    OllamaMessage.hpp OllamaMessage.cpp
    OpenAIMessage.hpp OpenAIMessage.cpp
    ValidationUtils.hpp ValidationUtils.cpp
    MessageBuilder.hpp MessageBuilder.cpp
)

target_link_libraries(LLMCore
@@ -20,14 +20,23 @@
#pragma once

#include <QString>
#include <QVector>

namespace QodeAssist::LLMCore {

struct Message
{
    QString role;
    QString content;
};

struct ContextData
{
    QString prefix;
    QString suffix;
    QString fileContext;
    std::optional<QString> systemPrompt;
    std::optional<QString> prefix;
    std::optional<QString> suffix;
    std::optional<QString> fileContext;
    std::optional<QVector<Message>> history;
};

} // namespace QodeAssist::LLMCore
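Every ContextData field is now std::optional, so consumers must distinguish "absent" from "empty"; the call sites in this diff use value_or() and has_value(). A minimal sketch of the two access patterns, assuming the struct above (and that <optional> is available through this header's includes):

    void illustrate(const QodeAssist::LLMCore::ContextData &ctx)
    {
        // Collapse "absent" to a default, as handleCompletion does:
        const QString prefix = ctx.prefix.value_or("");

        // Keep the absent/present distinction, as the fileContext check does:
        if (ctx.fileContext.has_value()) {
            // e.g. append ctx.fileContext.value() to the system prompt
        }
    }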
@@ -1,93 +0,0 @@
/*
 * Copyright (C) 2024 Petr Mironychev
 *
 * This file is part of QodeAssist.
 *
 * QodeAssist is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * QodeAssist is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
 */

#include "MessageBuilder.hpp"

QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addSystemMessage(
    const QString &content)
{
    m_systemMessage = content;
    return *this;
}

QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addUserMessage(
    const QString &content)
{
    m_messages.append({MessageRole::User, content});
    return *this;
}

QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addSuffix(
    const QString &content)
{
    m_suffix = content;
    return *this;
}

QodeAssist::LLMCore::MessageBuilder &QodeAssist::LLMCore::MessageBuilder::addTokenizer(
    PromptTemplate *promptTemplate)
{
    m_promptTemplate = promptTemplate;
    return *this;
}

QString QodeAssist::LLMCore::MessageBuilder::roleToString(MessageRole role) const
{
    switch (role) {
    case MessageRole::System:
        return ROLE_SYSTEM;
    case MessageRole::User:
        return ROLE_USER;
    case MessageRole::Assistant:
        return ROLE_ASSISTANT;
    default:
        return ROLE_USER;
    }
}

void QodeAssist::LLMCore::MessageBuilder::saveTo(QJsonObject &request, ProvidersApi api)
{
    if (!m_promptTemplate) {
        return;
    }

    ContextData context{
        m_messages.isEmpty() ? QString() : m_messages.last().content, m_suffix, m_systemMessage};

    if (api == ProvidersApi::Ollama) {
        if (m_promptTemplate->type() == TemplateType::Fim) {
            request["system"] = m_systemMessage;
            m_promptTemplate->prepareRequest(request, context);
        } else {
            QJsonArray messages;

            messages.append(QJsonObject{{"role", "system"}, {"content", m_systemMessage}});
            messages.append(QJsonObject{{"role", "user"}, {"content", m_messages.last().content}});
            request["messages"] = messages;
            m_promptTemplate->prepareRequest(request, context);
        }
    } else if (api == ProvidersApi::OpenAI) {
        QJsonArray messages;

        messages.append(QJsonObject{{"role", "system"}, {"content", m_systemMessage}});
        messages.append(QJsonObject{{"role", "user"}, {"content", m_messages.last().content}});
        request["messages"] = messages;
        m_promptTemplate->prepareRequest(request, context);
    }
}
@@ -1,68 +0,0 @@
/*
 * Copyright (C) 2024 Petr Mironychev
 *
 * This file is part of QodeAssist.
 *
 * QodeAssist is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * QodeAssist is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
 */

#pragma once

#include <QJsonArray>
#include <QJsonObject>
#include <QString>
#include <QVector>

#include "PromptTemplate.hpp"

namespace QodeAssist::LLMCore {

enum class MessageRole { System, User, Assistant };

enum class OllamaFormat { Messages, Completions };

enum class ProvidersApi { Ollama, OpenAI, Claude };

static const QString ROLE_SYSTEM = "system";
static const QString ROLE_USER = "user";
static const QString ROLE_ASSISTANT = "assistant";

struct Message
{
    MessageRole role;
    QString content;
};

class MessageBuilder
{
public:
    MessageBuilder &addSystemMessage(const QString &content);

    MessageBuilder &addUserMessage(const QString &content);

    MessageBuilder &addSuffix(const QString &content);

    MessageBuilder &addTokenizer(PromptTemplate *promptTemplate);

    QString roleToString(MessageRole role) const;

    void saveTo(QJsonObject &request, ProvidersApi api);

private:
    QString m_systemMessage;
    QString m_suffix;
    QVector<Message> m_messages;
    PromptTemplate *m_promptTemplate;
};
} // namespace QodeAssist::LLMCore
@@ -27,7 +27,7 @@

namespace QodeAssist::LLMCore {

enum class TemplateType { Chat, Fim };
enum class TemplateType { Chat, FIM };

class PromptTemplate
{
@@ -35,7 +35,7 @@ public:
    virtual ~PromptTemplate() = default;
    virtual TemplateType type() const = 0;
    virtual QString name() const = 0;
    virtual QString promptTemplate() const = 0;
    // virtual QString promptTemplate() const = 0;
    virtual QStringList stopWords() const = 0;
    virtual void prepareRequest(QJsonObject &request, const ContextData &context) const = 0;
    virtual QString description() const = 0;
@@ -23,6 +23,7 @@
#include <QNetworkRequest>
#include <QString>

#include "ContextData.hpp"
#include "PromptTemplate.hpp"
#include "RequestType.hpp"

@@ -42,7 +43,12 @@ public:
    virtual QString chatEndpoint() const = 0;
    virtual bool supportsModelListing() const = 0;

    virtual void prepareRequest(QJsonObject &request, RequestType type) = 0;
    virtual void prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type)
        = 0;
    virtual bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) = 0;
    virtual QList<QString> getInstalledModels(const QString &url) = 0;
    virtual QList<QString> validateRequest(const QJsonObject &request, TemplateType type) = 0;
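With the new signature, each provider receives the prompt template and context and is responsible for calling prompt->prepareRequest() itself, then layering its own model parameters on top. A hypothetical provider showing the contract ("ExampleProvider" is illustrative; the real providers in this commit follow the same shape):

    void ExampleProvider::prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type)
    {
        // 1. Let the template serialize the system prompt, prefix/suffix,
        //    or chat history into the request body.
        prompt->prepareRequest(request, context);

        // 2. Apply sampling parameters from the matching settings page.
        auto applyModelParams = [&request](const auto &settings) {
            request["max_tokens"] = settings.maxTokens();
            request["temperature"] = settings.temperature();
        };

        if (type == LLMCore::RequestType::CodeCompletion)
            applyModelParams(Settings::codeCompletionSettings());
        else
            applyModelParams(Settings::chatAssistantSettings());
    }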
@@ -58,7 +58,10 @@ void RequestHandler::sendLLMRequest(const LLMConfig &config, const QJsonObject &request)
    reply->deleteLater();
    m_activeRequests.remove(requestId);
    if (reply->error() != QNetworkReply::NoError) {
        LOG_MESSAGE(QString("Error in QodeAssist request: %1").arg(reply->errorString()));
        LOG_MESSAGE(QString("Error details: %1\nStatus code: %2\nResponse: %3")
                        .arg(reply->errorString())
                        .arg(reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt())
                        .arg(QString(reply->readAll())));
        emit requestFinished(requestId, false, reply->errorString());
    } else {
        LOG_MESSAGE("Request finished successfully");
@@ -21,5 +21,5 @@

namespace QodeAssist::LLMCore {

enum RequestType { CodeCompletion, Chat };
enum RequestType { CodeCompletion, Chat, Embedding };
}
@@ -30,13 +30,10 @@
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
#include "settings/GeneralSettings.hpp"
#include "settings/ProviderSettings.hpp"

namespace QodeAssist::Providers {

ClaudeProvider::ClaudeProvider() {}

QString ClaudeProvider::name() const
{
    return "Claude";
@@ -62,31 +59,17 @@ bool ClaudeProvider::supportsModelListing() const
    return true;
}

void ClaudeProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void ClaudeProvider::prepareRequest(
    QJsonObject &request,
    LLMCore::PromptTemplate *prompt,
    LLMCore::ContextData context,
    LLMCore::RequestType type)
{
    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
        QJsonArray messages;
        if (req.contains("messages")) {
            QJsonArray origMessages = req["messages"].toArray();
            for (const auto &msg : origMessages) {
                QJsonObject message = msg.toObject();
                if (message["role"].toString() == "system") {
                    req["system"] = message["content"];
                } else {
                    messages.append(message);
                }
            }
        } else {
            if (req.contains("system")) {
                req["system"] = req["system"].toString();
            }
            if (req.contains("prompt")) {
                messages.append(
                    QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
            }
        }
        return messages;
    };
    // if (!isSupportedTemplate(prompt->name())) {
    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
    // }

    prompt->prepareRequest(request, context);

    auto applyModelParams = [&request](const auto &settings) {
        request["max_tokens"] = settings.maxTokens();
@@ -98,11 +81,6 @@ void ClaudeProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
        request["stream"] = true;
    };

    QJsonArray messages = prepareMessages(request);
    if (!messages.isEmpty()) {
        request["messages"] = std::move(messages);
    }

    if (type == LLMCore::RequestType::CodeCompletion) {
        applyModelParams(Settings::codeCompletionSettings());
    } else {
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class ClaudeProvider : public LLMCore::Provider
{
public:
    ClaudeProvider();

    QString name() const override;
    QString url() const override;
    QString completionEndpoint() const override;
    QString chatEndpoint() const override;
    bool supportsModelListing() const override;
    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
    void prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type) override;
    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
    QList<QString> getInstalledModels(const QString &url) override;
    QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
@@ -33,8 +33,6 @@

namespace QodeAssist::Providers {

LMStudioProvider::LMStudioProvider() {}

QString LMStudioProvider::name() const
{
    return "LM Studio";
@@ -60,47 +58,6 @@ bool LMStudioProvider::supportsModelListing() const
    return true;
}

void LMStudioProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
{
    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
        QJsonArray messages;
        if (req.contains("system")) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
        }
        if (req.contains("prompt")) {
            messages.append(
                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
        }
        return messages;
    };

    auto applyModelParams = [&request](const auto &settings) {
        request["max_tokens"] = settings.maxTokens();
        request["temperature"] = settings.temperature();

        if (settings.useTopP())
            request["top_p"] = settings.topP();
        if (settings.useTopK())
            request["top_k"] = settings.topK();
        if (settings.useFrequencyPenalty())
            request["frequency_penalty"] = settings.frequencyPenalty();
        if (settings.usePresencePenalty())
            request["presence_penalty"] = settings.presencePenalty();
    };

    QJsonArray messages = prepareMessages(request);
    if (!messages.isEmpty()) {
        request["messages"] = std::move(messages);
    }

    if (type == LLMCore::RequestType::CodeCompletion) {
        applyModelParams(Settings::codeCompletionSettings());
    } else {
        applyModelParams(Settings::chatAssistantSettings());
    }
}

bool LMStudioProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
    QByteArray data = reply->readAll();
@@ -211,4 +168,37 @@ void LMStudioProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const
    networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
}

void QodeAssist::Providers::LMStudioProvider::prepareRequest(
    QJsonObject &request,
    LLMCore::PromptTemplate *prompt,
    LLMCore::ContextData context,
    LLMCore::RequestType type)
{
    // if (!isSupportedTemplate(prompt->name())) {
    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
    // }

    prompt->prepareRequest(request, context);

    auto applyModelParams = [&request](const auto &settings) {
        request["max_tokens"] = settings.maxTokens();
        request["temperature"] = settings.temperature();

        if (settings.useTopP())
            request["top_p"] = settings.topP();
        if (settings.useTopK())
            request["top_k"] = settings.topK();
        if (settings.useFrequencyPenalty())
            request["frequency_penalty"] = settings.frequencyPenalty();
        if (settings.usePresencePenalty())
            request["presence_penalty"] = settings.presencePenalty();
    };

    if (type == LLMCore::RequestType::CodeCompletion) {
        applyModelParams(Settings::codeCompletionSettings());
    } else {
        applyModelParams(Settings::chatAssistantSettings());
    }
}

} // namespace QodeAssist::Providers
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class LMStudioProvider : public LLMCore::Provider
{
public:
    LMStudioProvider();

    QString name() const override;
    QString url() const override;
    QString completionEndpoint() const override;
    QString chatEndpoint() const override;
    bool supportsModelListing() const override;
    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
    void prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type) override;
    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
    QList<QString> getInstalledModels(const QString &url) override;
    QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
providers/MistralAIProvider.cpp (new file, 215 lines)
@@ -0,0 +1,215 @@
#include "MistralAIProvider.hpp"

#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
#include "settings/ProviderSettings.hpp"

#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QNetworkReply>
#include <QtCore/qeventloop.h>

#include "llmcore/OpenAIMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"

namespace QodeAssist::Providers {

QString MistralAIProvider::name() const
{
    return "Mistral AI";
}

QString MistralAIProvider::url() const
{
    return "https://api.mistral.ai";
}

QString MistralAIProvider::completionEndpoint() const
{
    return "/v1/fim/completions";
}

QString MistralAIProvider::chatEndpoint() const
{
    return "/v1/chat/completions";
}

bool MistralAIProvider::supportsModelListing() const
{
    return true;
}

bool MistralAIProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
    QByteArray data = reply->readAll();
    if (data.isEmpty()) {
        return false;
    }

    bool isDone = false;
    QByteArrayList lines = data.split('\n');

    for (const QByteArray &line : lines) {
        if (line.trimmed().isEmpty()) {
            continue;
        }

        if (line == "data: [DONE]") {
            isDone = true;
            continue;
        }

        QByteArray jsonData = line;
        if (line.startsWith("data: ")) {
            jsonData = line.mid(6);
        }

        QJsonParseError error;
        QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error);

        if (doc.isNull()) {
            continue;
        }

        auto message = LLMCore::OpenAIMessage::fromJson(doc.object());
        if (message.hasError()) {
            LOG_MESSAGE("Error in OpenAI response: " + message.error);
            continue;
        }

        QString content = message.getContent();
        if (!content.isEmpty()) {
            accumulatedResponse += content;
        }

        if (message.isDone()) {
            isDone = true;
        }
    }

    return isDone;
}

QList<QString> MistralAIProvider::getInstalledModels(const QString &url)
{
    QList<QString> models;
    QNetworkAccessManager manager;
    QNetworkRequest request(QString("%1/v1/models").arg(url));

    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
    if (!apiKey().isEmpty()) {
        request.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8());
    }

    QNetworkReply *reply = manager.get(request);
    QEventLoop loop;
    QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
    loop.exec();

    if (reply->error() == QNetworkReply::NoError) {
        QByteArray responseData = reply->readAll();
        QJsonDocument jsonResponse = QJsonDocument::fromJson(responseData);
        QJsonObject jsonObject = jsonResponse.object();

        if (jsonObject.contains("data") && jsonObject["object"].toString() == "list") {
            QJsonArray modelArray = jsonObject["data"].toArray();
            for (const QJsonValue &value : modelArray) {
                QJsonObject modelObject = value.toObject();
                if (modelObject.contains("id")) {
                    QString modelId = modelObject["id"].toString();
                    models.append(modelId);
                }
            }
        }
    } else {
        LOG_MESSAGE(QString("Error fetching Mistral AI models: %1").arg(reply->errorString()));
    }

    reply->deleteLater();
    return models;
}

QList<QString> MistralAIProvider::validateRequest(
    const QJsonObject &request, LLMCore::TemplateType type)
{
    const auto fimReq = QJsonObject{
        {"model", {}},
        {"max_tokens", {}},
        {"stream", {}},
        {"temperature", {}},
        {"prompt", {}},
        {"suffix", {}}};

    const auto templateReq = QJsonObject{
        {"model", {}},
        {"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
        {"temperature", {}},
        {"max_tokens", {}},
        {"top_p", {}},
        {"frequency_penalty", {}},
        {"presence_penalty", {}},
        {"stop", QJsonArray{}},
        {"stream", {}}};

    return LLMCore::ValidationUtils::validateRequestFields(
        request, type == LLMCore::TemplateType::FIM ? fimReq : templateReq);
}

QString MistralAIProvider::apiKey() const
{
    return Settings::providerSettings().mistralAiApiKey();
}

void MistralAIProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const
{
    networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");

    if (!apiKey().isEmpty()) {
        networkRequest.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8());
    }
}

void MistralAIProvider::prepareRequest(
    QJsonObject &request,
    LLMCore::PromptTemplate *prompt,
    LLMCore::ContextData context,
    LLMCore::RequestType type)
{
    // if (!isSupportedTemplate(prompt->name())) {
    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
    // }

    prompt->prepareRequest(request, context);

    if (type == LLMCore::RequestType::Chat) {
        auto &settings = Settings::chatAssistantSettings();

        request["max_tokens"] = settings.maxTokens();
        request["temperature"] = settings.temperature();

        if (settings.useTopP())
            request["top_p"] = settings.topP();

        // request["random_seed"] = "";

        if (settings.useFrequencyPenalty())
            request["frequency_penalty"] = settings.frequencyPenalty();
        if (settings.usePresencePenalty())
            request["presence_penalty"] = settings.presencePenalty();

    } else {
        auto &settings = Settings::codeCompletionSettings();

        request["max_tokens"] = settings.maxTokens();
        request["temperature"] = settings.temperature();

        if (settings.useTopP())
            request["top_p"] = settings.topP();

        // request["random_seed"] = "";
    }
}

} // namespace QodeAssist::Providers
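For orientation, the body this provider posts to /v1/fim/completions can contain exactly the fields whitelisted in validateRequest() above; a sketch with illustrative values (the model id and numbers come from the user's settings, not this diff):

    {
      "model": "model-id-from-settings",
      "prompt": "text before the cursor",
      "suffix": "text after the cursor",
      "max_tokens": 50,
      "temperature": 0.2,
      "stream": true
    }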
providers/MistralAIProvider.hpp (new file, 46 lines)
@@ -0,0 +1,46 @@
/*
 * Copyright (C) 2024 Petr Mironychev
 *
 * This file is part of QodeAssist.
 *
 * QodeAssist is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * QodeAssist is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
 */

#pragma once

#include "llmcore/Provider.hpp"

namespace QodeAssist::Providers {

class MistralAIProvider : public LLMCore::Provider
{
public:
    QString name() const override;
    QString url() const override;
    QString completionEndpoint() const override;
    QString chatEndpoint() const override;
    bool supportsModelListing() const override;
    void prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type) override;
    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
    QList<QString> getInstalledModels(const QString &url) override;
    QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
    QString apiKey() const override;
    void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
};

} // namespace QodeAssist::Providers
@@ -33,8 +33,6 @@

namespace QodeAssist::Providers {

OllamaProvider::OllamaProvider() {}

QString OllamaProvider::name() const
{
    return "Ollama";
@@ -60,8 +58,18 @@ bool OllamaProvider::supportsModelListing() const
    return true;
}

void OllamaProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void OllamaProvider::prepareRequest(
    QJsonObject &request,
    LLMCore::PromptTemplate *prompt,
    LLMCore::ContextData context,
    LLMCore::RequestType type)
{
    // if (!isSupportedTemplate(prompt->name())) {
    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
    // }

    prompt->prepareRequest(request, context);

    auto applySettings = [&request](const auto &settings) {
        QJsonObject options;
        options["num_predict"] = settings.maxTokens();
@@ -192,7 +200,7 @@ QList<QString> OllamaProvider::validateRequest(const QJsonObject &request, LLMCore::TemplateType type)
        {"presence_penalty", {}}}}};

    return LLMCore::ValidationUtils::validateRequestFields(
        request, type == LLMCore::TemplateType::Fim ? fimReq : messageReq);
        request, type == LLMCore::TemplateType::FIM ? fimReq : messageReq);
}

QString OllamaProvider::apiKey() const
@@ -203,6 +211,6 @@ QString OllamaProvider::apiKey() const
void OllamaProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const
{
    networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
};
}

} // namespace QodeAssist::Providers
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class OllamaProvider : public LLMCore::Provider
{
public:
    OllamaProvider();

    QString name() const override;
    QString url() const override;
    QString completionEndpoint() const override;
    QString chatEndpoint() const override;
    bool supportsModelListing() const override;
    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
    void prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type) override;
    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
    QList<QString> getInstalledModels(const QString &url) override;
    QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
@@ -34,8 +34,6 @@

namespace QodeAssist::Providers {

OpenAICompatProvider::OpenAICompatProvider() {}

QString OpenAICompatProvider::name() const
{
    return "OpenAI Compatible";
@@ -61,20 +59,17 @@ bool OpenAICompatProvider::supportsModelListing() const
    return false;
}

void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void OpenAICompatProvider::prepareRequest(
    QJsonObject &request,
    LLMCore::PromptTemplate *prompt,
    LLMCore::ContextData context,
    LLMCore::RequestType type)
{
    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
        QJsonArray messages;
        if (req.contains("system")) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
        }
        if (req.contains("prompt")) {
            messages.append(
                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
        }
        return messages;
    };
    // if (!isSupportedTemplate(prompt->name())) {
    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
    // }

    prompt->prepareRequest(request, context);

    auto applyModelParams = [&request](const auto &settings) {
        request["max_tokens"] = settings.maxTokens();
@@ -90,11 +85,6 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
        request["presence_penalty"] = settings.presencePenalty();
    };

    QJsonArray messages = prepareMessages(request);
    if (!messages.isEmpty()) {
        request["messages"] = std::move(messages);
    }

    if (type == LLMCore::RequestType::CodeCompletion) {
        applyModelParams(Settings::codeCompletionSettings());
    } else {
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class OpenAICompatProvider : public LLMCore::Provider
{
public:
    OpenAICompatProvider();

    QString name() const override;
    QString url() const override;
    QString completionEndpoint() const override;
    QString chatEndpoint() const override;
    bool supportsModelListing() const override;
    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
    void prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type) override;
    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
    QList<QString> getInstalledModels(const QString &url) override;
    QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
@@ -35,8 +35,6 @@

namespace QodeAssist::Providers {

OpenAIProvider::OpenAIProvider() {}

QString OpenAIProvider::name() const
{
    return "OpenAI";
@@ -62,20 +60,17 @@ bool OpenAIProvider::supportsModelListing() const
    return true;
}

void OpenAIProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
void OpenAIProvider::prepareRequest(
    QJsonObject &request,
    LLMCore::PromptTemplate *prompt,
    LLMCore::ContextData context,
    LLMCore::RequestType type)
{
    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
        QJsonArray messages;
        if (req.contains("system")) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
        }
        if (req.contains("prompt")) {
            messages.append(
                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
        }
        return messages;
    };
    // if (!isSupportedTemplate(prompt->name())) {
    //     LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
    // }

    prompt->prepareRequest(request, context);

    auto applyModelParams = [&request](const auto &settings) {
        request["max_tokens"] = settings.maxTokens();
@@ -91,11 +86,6 @@ void OpenAIProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
        request["presence_penalty"] = settings.presencePenalty();
    };

    QJsonArray messages = prepareMessages(request);
    if (!messages.isEmpty()) {
        request["messages"] = std::move(messages);
    }

    if (type == LLMCore::RequestType::CodeCompletion) {
        applyModelParams(Settings::codeCompletionSettings());
    } else {
@@ -26,14 +26,16 @@ namespace QodeAssist::Providers {
class OpenAIProvider : public LLMCore::Provider
{
public:
    OpenAIProvider();

    QString name() const override;
    QString url() const override;
    QString completionEndpoint() const override;
    QString chatEndpoint() const override;
    bool supportsModelListing() const override;
    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
    void prepareRequest(
        QJsonObject &request,
        LLMCore::PromptTemplate *prompt,
        LLMCore::ContextData context,
        LLMCore::RequestType type) override;
    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
    QList<QString> getInstalledModels(const QString &url) override;
    QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
@@ -33,8 +33,6 @@

namespace QodeAssist::Providers {

OpenRouterProvider::OpenRouterProvider() {}

QString OpenRouterProvider::name() const
{
    return "OpenRouter";
@@ -45,47 +43,6 @@ QString OpenRouterProvider::url() const
    return "https://openrouter.ai/api";
}

void OpenRouterProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
{
    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
        QJsonArray messages;
        if (req.contains("system")) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
        }
        if (req.contains("prompt")) {
            messages.append(
                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
        }
        return messages;
    };

    auto applyModelParams = [&request](const auto &settings) {
        request["max_tokens"] = settings.maxTokens();
        request["temperature"] = settings.temperature();

        if (settings.useTopP())
            request["top_p"] = settings.topP();
        if (settings.useTopK())
            request["top_k"] = settings.topK();
        if (settings.useFrequencyPenalty())
            request["frequency_penalty"] = settings.frequencyPenalty();
        if (settings.usePresencePenalty())
            request["presence_penalty"] = settings.presencePenalty();
    };

    QJsonArray messages = prepareMessages(request);
    if (!messages.isEmpty()) {
        request["messages"] = std::move(messages);
    }

    if (type == LLMCore::RequestType::CodeCompletion) {
        applyModelParams(Settings::codeCompletionSettings());
    } else {
        applyModelParams(Settings::chatAssistantSettings());
    }
}

bool OpenRouterProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
    QByteArray data = reply->readAll();
@@ -27,11 +27,8 @@ namespace QodeAssist::Providers {
class OpenRouterProvider : public OpenAICompatProvider
{
public:
    OpenRouterProvider();

    QString name() const override;
    QString url() const override;
    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
    QString apiKey() const override;
};
@@ -22,6 +22,7 @@
#include "llmcore/ProvidersManager.hpp"
#include "providers/ClaudeProvider.hpp"
#include "providers/LMStudioProvider.hpp"
#include "providers/MistralAIProvider.hpp"
#include "providers/OllamaProvider.hpp"
#include "providers/OpenAICompatProvider.hpp"
#include "providers/OpenAIProvider.hpp"
@@ -33,11 +34,12 @@ inline void registerProviders()
{
    auto &providerManager = LLMCore::ProvidersManager::instance();
    providerManager.registerProvider<OllamaProvider>();
    providerManager.registerProvider<LMStudioProvider>();
    providerManager.registerProvider<OpenAICompatProvider>();
    providerManager.registerProvider<OpenRouterProvider>();
    providerManager.registerProvider<ClaudeProvider>();
    providerManager.registerProvider<OpenAIProvider>();
    providerManager.registerProvider<OpenAICompatProvider>();
    providerManager.registerProvider<LMStudioProvider>();
    providerManager.registerProvider<OpenRouterProvider>();
    providerManager.registerProvider<MistralAIProvider>();
}

} // namespace QodeAssist::Providers
@@ -69,6 +69,7 @@ ProviderSettings::ProviderSettings()
    claudeApiKey.setDefaultValue("");
    claudeApiKey.setAutoApply(true);

    // OpenAI Settings
    openAiApiKey.setSettingsKey(Constants::OPEN_AI_API_KEY);
    openAiApiKey.setLabelText(Tr::tr("OpenAI API Key:"));
    openAiApiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
@@ -77,6 +78,15 @@ ProviderSettings::ProviderSettings()
    openAiApiKey.setDefaultValue("");
    openAiApiKey.setAutoApply(true);

    // MistralAI Settings
    mistralAiApiKey.setSettingsKey(Constants::MISTRAL_AI_API_KEY);
    mistralAiApiKey.setLabelText(Tr::tr("Mistral AI API Key:"));
    mistralAiApiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay);
    mistralAiApiKey.setPlaceHolderText(Tr::tr("Enter your API key here"));
    mistralAiApiKey.setHistoryCompleter(Constants::MISTRAL_AI_API_KEY_HISTORY);
    mistralAiApiKey.setDefaultValue("");
    mistralAiApiKey.setAutoApply(true);

    resetToDefaults.m_buttonText = Tr::tr("Reset Page to Defaults");

    readSettings();
@@ -96,6 +106,8 @@ ProviderSettings::ProviderSettings()
        Group{title(Tr::tr("OpenAI Compatible Settings")), Column{openAiCompatApiKey}},
        Space{8},
        Group{title(Tr::tr("Claude Settings")), Column{claudeApiKey}},
        Space{8},
        Group{title(Tr::tr("Mistral AI Settings")), Column{mistralAiApiKey}},
        Stretch{1}};
    });
}
@@ -37,6 +37,7 @@ public:
    Utils::StringAspect openAiCompatApiKey{this};
    Utils::StringAspect claudeApiKey{this};
    Utils::StringAspect openAiApiKey{this};
    Utils::StringAspect mistralAiApiKey{this};

private:
    void setupConnections();
@@ -100,6 +100,8 @@ const char CLAUDE_API_KEY[] = "QodeAssist.claudeApiKey";
const char CLAUDE_API_KEY_HISTORY[] = "QodeAssist.claudeApiKeyHistory";
const char OPEN_AI_API_KEY[] = "QodeAssist.openAiApiKey";
const char OPEN_AI_API_KEY_HISTORY[] = "QodeAssist.openAiApiKeyHistory";
const char MISTRAL_AI_API_KEY[] = "QodeAssist.mistralAiApiKey";
const char MISTRAL_AI_API_KEY_HISTORY[] = "QodeAssist.mistralAiApiKeyHistory";

// context settings
const char CC_READ_FULL_FILE[] = "QodeAssist.ccReadFullFile";
@@ -29,33 +29,32 @@ class Alpaca : public LLMCore::PromptTemplate
public:
    QString name() const override { return "Alpaca"; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString promptTemplate() const override { return {}; }
    QStringList stopWords() const override
    {
        return QStringList() << "### Instruction:" << "### Response:";
    }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages = request["messages"].toArray();
        QJsonArray messages;

        for (int i = 0; i < messages.size(); ++i) {
            QJsonObject message = messages[i].toObject();
            QString role = message["role"].toString();
            QString content = message["content"].toString();
        QString fullContent;

            QString formattedContent;
            if (role == "system") {
                formattedContent = content + "\n\n";
            } else if (role == "user") {
                formattedContent = "### Instruction:\n" + content + "\n\n";
            } else if (role == "assistant") {
                formattedContent = "### Response:\n" + content + "\n\n";
            }

            message["content"] = formattedContent;
            messages[i] = message;
        if (context.systemPrompt) {
            fullContent += context.systemPrompt.value() + "\n\n";
        }

        if (context.history) {
            for (const auto &msg : context.history.value()) {
                if (msg.role == "user") {
                    fullContent += QString("### Instruction:\n%1\n\n").arg(msg.content);
                } else if (msg.role == "assistant") {
                    fullContent += QString("### Response:\n%1\n\n").arg(msg.content);
                }
            }
        }

        messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});

        request["messages"] = messages;
    }
    QString description() const override
@@ -30,23 +30,28 @@ class ChatML : public LLMCore::PromptTemplate
public:
    QString name() const override { return "ChatML"; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString promptTemplate() const override { return {}; }
    QStringList stopWords() const override
    {
        return QStringList() << "<|im_start|>" << "<|im_end|>";
    }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages = request["messages"].toArray();
        QJsonArray messages;

        for (int i = 0; i < messages.size(); ++i) {
            QJsonObject message = messages[i].toObject();
            QString role = message["role"].toString();
            QString content = message["content"].toString();
        if (context.systemPrompt) {
            messages.append(QJsonObject{
                {"role", "system"},
                {"content",
                 QString("<|im_start|>system\n%2\n<|im_end|>").arg(context.systemPrompt.value())}});
        }

            message["content"] = QString("<|im_start|>%1\n%2\n<|im_end|>").arg(role, content);

            messages[i] = message;
        if (context.history) {
            for (const auto &msg : context.history.value()) {
                messages.append(QJsonObject{
                    {"role", msg.role},
                    {"content",
                     QString("<|im_start|>%1\n%2\n<|im_end|>").arg(msg.role, msg.content)}});
            }
        }

        request["messages"] = messages;
@@ -30,9 +30,25 @@ class Claude : public LLMCore::PromptTemplate
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString name() const override { return "Claude"; }
    QString promptTemplate() const override { return {}; }
    QStringList stopWords() const override { return QStringList(); }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages;

        if (context.systemPrompt) {
            request["system"] = context.systemPrompt.value();
        }

        if (context.history) {
            for (const auto &msg : context.history.value()) {
                if (msg.role != "system") {
                    messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
                }
            }
        }

        request["messages"] = messages;
    }
    QString description() const override { return "Claude"; }
};
@@ -26,17 +26,17 @@ namespace QodeAssist::Templates {
class CodeLlamaFim : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
    QString name() const override { return "CodeLlama FIM"; }
    QString promptTemplate() const override { return "<PRE> %1 <SUF>%2 <MID>"; }
    QStringList stopWords() const override
    {
        return QStringList() << "<EOT>" << "<PRE>" << "<SUF" << "<MID>";
    }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
        request["prompt"] = formattedPrompt;
        request["prompt"] = QString("<PRE> %1 <SUF>%2 <MID>")
                                .arg(context.prefix.value_or(""), context.suffix.value_or(""));
        request["system"] = context.systemPrompt.value_or("");
    }
    QString description() const override
    {
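Concretely, for a prefix of "int main() {" and a suffix of "}", the new prepareRequest() above fills the request roughly as follows (a sketch, assuming no system prompt is set):

    {
      "prompt": "<PRE> int main() { <SUF>} <MID>",
      "system": ""
    }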
@@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
class CodeLlamaQMLFim : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
    QString name() const override { return "CodeLlama QML FIM"; }
    QString promptTemplate() const override { return "<SUF>%1<PRE>%2<MID>"; }
    QStringList stopWords() const override
    {
        return QStringList() << "<SUF>" << "<PRE>" << "</PRE>" << "</SUF>" << "< EOT >" << "\\end"
@@ -36,8 +35,9 @@ public:
    }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QString formattedPrompt = promptTemplate().arg(context.suffix, context.prefix);
        request["prompt"] = formattedPrompt;
        request["prompt"] = QString("<SUF>%1<PRE>%2<MID>")
                                .arg(context.suffix.value_or(""), context.prefix.value_or(""));
        request["system"] = context.systemPrompt.value_or("");
    }
    QString description() const override
    {
@@ -29,30 +29,30 @@ class Llama2 : public LLMCore::PromptTemplate
public:
    QString name() const override { return "Llama 2"; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString promptTemplate() const override { return {}; }
    QStringList stopWords() const override { return QStringList() << "[INST]"; }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages = request["messages"].toArray();
        QJsonArray messages;

        for (int i = 0; i < messages.size(); ++i) {
            QJsonObject message = messages[i].toObject();
            QString role = message["role"].toString();
            QString content = message["content"].toString();
        QString fullContent;

            QString formattedContent;
            if (role == "system") {
                formattedContent = QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(content);
            } else if (role == "user") {
                formattedContent = QString("[INST]%1[/INST]\n").arg(content);
            } else if (role == "assistant") {
                formattedContent = content + "\n";
            }

            message["content"] = formattedContent;
            messages[i] = message;
        if (context.systemPrompt) {
            fullContent
                += QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(context.systemPrompt.value());
        }

        if (context.history) {
            for (const auto &msg : context.history.value()) {
                if (msg.role == "user") {
                    fullContent += QString("[INST]%1[/INST]\n").arg(msg.content);
                } else if (msg.role == "assistant") {
                    fullContent += msg.content + "\n";
                }
            }
        }

        messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});

        request["messages"] = messages;
    }
    QString description() const override
@ -30,24 +30,30 @@ class Llama3 : public LLMCore::PromptTemplate
public:
    QString name() const override { return "Llama 3"; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString promptTemplate() const override { return ""; }
    QStringList stopWords() const override
    {
        return QStringList() << "<|start_header_id|>" << "<|end_header_id|>" << "<|eot_id|>";
    }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages = request["messages"].toArray();
        QJsonArray messages;

        for (int i = 0; i < messages.size(); ++i) {
            QJsonObject message = messages[i].toObject();
            QString role = message["role"].toString();
            QString content = message["content"].toString();
        if (context.systemPrompt) {
            messages.append(QJsonObject{
                {"role", "system"},
                {"content",
                 QString("<|start_header_id|>system<|end_header_id|>%2<|eot_id|>")
                     .arg(context.systemPrompt.value())}});
        }

            message["content"]
                = QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>").arg(role, content);

            messages[i] = message;
        if (context.history) {
            for (const auto &msg : context.history.value()) {
                messages.append(QJsonObject{
                    {"role", msg.role},
                    {"content",
                     QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>")
                         .arg(msg.role, msg.content)}});
            }
        }

        request["messages"] = messages;
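
Llama 3 takes the opposite approach to Llama 2: it keeps one JSON message per turn and instead wraps each message's content in the header/eot tokens. A sketch of a single wrapped turn (the content is hypothetical):

#include <QDebug>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QString>

int main()
{
    QJsonArray messages;
    // One history entry, wrapped the way Llama3::prepareRequest does
    messages.append(QJsonObject{
        {"role", "user"},
        {"content",
         QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>").arg("user", "Hello")}});
    qDebug().noquote() << QJsonDocument(messages).toJson(QJsonDocument::Compact);
    // [{"content":"<|start_header_id|>user<|end_header_id|>Hello<|eot_id|>","role":"user"}]
    return 0;
}
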
templates/MistralAI.hpp (new file, 69 lines)
@ -0,0 +1,69 @@
/*
 * Copyright (C) 2024 Petr Mironychev
 *
 * This file is part of QodeAssist.
 *
 * QodeAssist is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * QodeAssist is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
 */

#pragma once

#include <QJsonArray>

#include "llmcore/PromptTemplate.hpp"

namespace QodeAssist::Templates {

class MistralAIFim : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
    QString name() const override { return "Mistral AI FIM"; }
    QStringList stopWords() const override { return QStringList(); }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        request["prompt"] = context.prefix.value_or("");
        request["suffix"] = context.suffix.value_or("");
    }
    QString description() const override { return "Mistral AI FIM template"; }
};

class MistralAIChat : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString name() const override { return "Mistral AI Chat"; }
    QStringList stopWords() const override { return QStringList(); }

    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages;

        if (context.systemPrompt) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
        }

        if (context.history) {
            for (const auto &msg : context.history.value()) {
                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
            }
        }

        request["messages"] = messages;
    }
    QString description() const override { return "Mistral AI chat template"; }
};

} // namespace QodeAssist::Templates
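
MistralAIFim maps the context straight onto top-level "prompt" and "suffix" fields, leaving all token formatting to the server. A sketch of the resulting request body (the model name is a hypothetical placeholder, set elsewhere in the real flow):

#include <QDebug>
#include <QJsonDocument>
#include <QJsonObject>

int main()
{
    // Mirrors MistralAIFim::prepareRequest with hypothetical context values
    QJsonObject request;
    request["model"] = "some-fim-model";            // hypothetical placeholder
    request["prompt"] = "int add(int a, int b) {";  // context.prefix
    request["suffix"] = "}";                        // context.suffix
    qDebug().noquote() << QJsonDocument(request).toJson();
    return 0;
}
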
@ -25,31 +25,44 @@

namespace QodeAssist::Templates {

class OllamaAutoFim : public LLMCore::PromptTemplate
class OllamaFim : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
    QString name() const override { return "Ollama Auto FIM"; }
    QString promptTemplate() const override { return {}; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
    QString name() const override { return "Ollama FIM"; }
    QStringList stopWords() const override { return QStringList(); }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        request["prompt"] = context.prefix;
        request["suffix"] = context.suffix;
        request["prompt"] = context.prefix.value_or("");
        request["suffix"] = context.suffix.value_or("");
        request["system"] = context.systemPrompt.value_or("");
    }
    QString description() const override { return "Template is taken from the Ollama modelfile"; }
};

class OllamaAutoChat : public LLMCore::PromptTemplate
class OllamaChat : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString name() const override { return "Ollama Auto Chat"; }
    QString promptTemplate() const override { return {}; }
    QString name() const override { return "Ollama Chat"; }
    QStringList stopWords() const override { return QStringList(); }

    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages;

        if (context.systemPrompt) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
        }

        if (context.history) {
            for (const auto &msg : context.history.value()) {
                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
            }
        }

        request["messages"] = messages;
    }
    QString description() const override { return "Template is taken from the Ollama modelfile"; }
};
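
After the rename, the Ollama pair cleanly separates the two request shapes: OllamaFim fills the flat prompt/suffix/system fields of a generate-style call, while OllamaChat builds a messages array. A sketch contrasting the two bodies (all values hypothetical):

#include <QDebug>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>

int main()
{
    // OllamaFim-style body: flat fields for fill-in-the-middle
    QJsonObject fim;
    fim["prompt"] = "def area(r):";
    fim["suffix"] = "    return result";
    fim["system"] = "You complete Python code.";

    // OllamaChat-style body: the same information as a messages array
    QJsonObject chat;
    chat["messages"] = QJsonArray{
        QJsonObject{{"role", "system"}, {"content", "You complete Python code."}},
        QJsonObject{{"role", "user"}, {"content", "Write the area of a circle."}}};

    qDebug().noquote() << QJsonDocument(fim).toJson();
    qDebug().noquote() << QJsonDocument(chat).toJson();
    return 0;
}
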
@ -30,9 +30,24 @@ class OpenAI : public LLMCore::PromptTemplate
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString name() const override { return "OpenAI"; }
    QString promptTemplate() const override { return {}; }
    QStringList stopWords() const override { return QStringList(); }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QJsonArray messages;

        if (context.systemPrompt) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
        }

        if (context.history) {
            for (const auto &msg : context.history.value()) {
                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
            }
        }

        request["messages"] = messages;
    }
    QString description() const override { return "OpenAI"; }
};
@ -25,15 +25,29 @@

namespace QodeAssist::Templates {

class BasicChat : public LLMCore::PromptTemplate
class OpenAICompatible : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
    QString name() const override { return "Basic Chat"; }
    QString promptTemplate() const override { return {}; }
    QString name() const override { return "OpenAI Compatible"; }
    QStringList stopWords() const override { return QStringList(); }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {}
    {
        QJsonArray messages;

        if (context.systemPrompt) {
            messages.append(
                QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
        }

        if (context.history) {
            for (const auto &msg : context.history.value()) {
                messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
            }
        }

        request["messages"] = messages;
    }
    QString description() const override { return "Chat without special tokens"; }
};
@ -28,16 +28,13 @@ class QwenFim : public LLMCore::PromptTemplate
{
public:
    QString name() const override { return "Qwen FIM"; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
    QString promptTemplate() const override
    {
        return "<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>";
    }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
    QStringList stopWords() const override { return QStringList() << "<|endoftext|>" << "<|EOT|>"; }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
        request["prompt"] = formattedPrompt;
        request["prompt"] = QString("<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>")
                                .arg(context.prefix.value_or(""), context.suffix.value_or(""));
        request["system"] = context.systemPrompt.value_or("");
    }
    QString description() const override
    {
@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
class StarCoder2Fim : public LLMCore::PromptTemplate
{
public:
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
    QString name() const override { return "StarCoder2 FIM"; }
    QString promptTemplate() const override { return "<fim_prefix>%1<fim_suffix>%2<fim_middle>"; }
    QStringList stopWords() const override
    {
        return QStringList() << "<|endoftext|>" << "<file_sep>" << "<fim_prefix>" << "<fim_suffix>"
@ -36,8 +35,9 @@ public:
    }
    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
    {
        QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
        request["prompt"] = formattedPrompt;
        request["prompt"] = QString("<fim_prefix>%1<fim_suffix>%2<fim_middle>")
                                .arg(context.prefix.value_or(""), context.suffix.value_or(""));
        request["system"] = context.systemPrompt.value_or("");
    }
    QString description() const override
    {
@ -21,17 +21,18 @@

#include "llmcore/PromptTemplateManager.hpp"
#include "templates/Alpaca.hpp"
#include "templates/BasicChat.hpp"
#include "templates/ChatML.hpp"
#include "templates/Claude.hpp"
#include "templates/CodeLlamaFim.hpp"
#include "templates/CodeLlamaQMLFim.hpp"
#include "templates/CustomFimTemplate.hpp"
#include "templates/DeepSeekCoderFim.hpp"
#include "templates/Llama2.hpp"
#include "templates/Llama3.hpp"
#include "templates/MistralAI.hpp"
#include "templates/Ollama.hpp"
#include "templates/OpenAI.hpp"
#include "templates/OpenAICompatible.hpp"
// #include "templates/CustomFimTemplate.hpp"
// #include "templates/DeepSeekCoderFim.hpp"
#include "templates/Llama2.hpp"
#include "templates/Llama3.hpp"
#include "templates/Qwen.hpp"
#include "templates/StarCoder2Fim.hpp"

@ -41,20 +42,22 @@ inline void registerTemplates()
{
    auto &templateManager = LLMCore::PromptTemplateManager::instance();
    templateManager.registerTemplate<CodeLlamaFim>();
    templateManager.registerTemplate<StarCoder2Fim>();
    templateManager.registerTemplate<DeepSeekCoderFim>();
    templateManager.registerTemplate<CustomTemplate>();
    templateManager.registerTemplate<QwenFim>();
    templateManager.registerTemplate<OllamaAutoFim>();
    templateManager.registerTemplate<OllamaAutoChat>();
    templateManager.registerTemplate<BasicChat>();
    templateManager.registerTemplate<Llama3>();
    templateManager.registerTemplate<ChatML>();
    templateManager.registerTemplate<Alpaca>();
    templateManager.registerTemplate<Llama2>();
    templateManager.registerTemplate<OllamaFim>();
    templateManager.registerTemplate<OllamaChat>();
    templateManager.registerTemplate<Claude>();
    templateManager.registerTemplate<OpenAI>();
    templateManager.registerTemplate<MistralAIFim>();
    templateManager.registerTemplate<MistralAIChat>();
    templateManager.registerTemplate<CodeLlamaQMLFim>();
    templateManager.registerTemplate<ChatML>();
    templateManager.registerTemplate<Llama2>();
    templateManager.registerTemplate<Llama3>();
    templateManager.registerTemplate<StarCoder2Fim>();
    // templateManager.registerTemplate<DeepSeekCoderFim>();
    // templateManager.registerTemplate<CustomTemplate>();
    templateManager.registerTemplate<QwenFim>();
    templateManager.registerTemplate<OpenAICompatible>();
    templateManager.registerTemplate<Alpaca>();
}

} // namespace QodeAssist::Templates
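
Every template registered above is consumed through the same narrow interface: a ContextData goes in, a provider-ready QJsonObject comes out. A minimal end-to-end sketch of that flow, using self-contained stand-ins for the LLMCore types rather than the real headers (the real definitions live under llmcore/ and are not reproduced in this diff):

#include <QDebug>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QString>
#include <QVector>
#include <optional>

// Stand-ins mirroring the fields this diff reads from ContextData
struct Message { QString role; QString content; };
struct ContextData
{
    std::optional<QString> prefix, suffix, systemPrompt;
    std::optional<QVector<Message>> history;
};

// Same shape as the chat templates above: system prompt first, then history
QJsonObject prepareChatRequest(const ContextData &context)
{
    QJsonObject request;
    QJsonArray messages;
    if (context.systemPrompt)
        messages.append(QJsonObject{{"role", "system"}, {"content", *context.systemPrompt}});
    if (context.history)
        for (const auto &msg : *context.history)
            messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
    request["messages"] = messages;
    return request;
}

int main()
{
    ContextData context;
    context.systemPrompt = "You are a coding assistant.";
    context.history = QVector<Message>{{"user", "Hello"}};
    qDebug().noquote() << QJsonDocument(prepareChatRequest(context)).toJson();
    return 0;
}
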