feat: Add filter templates for each provider

This commit is contained in:
Petr Mironychev 2025-02-23 01:41:47 +01:00
parent d96f44d42c
commit e924029ec2
36 changed files with 291 additions and 25 deletions

View File

@ -137,9 +137,14 @@ void ConfigurationManager::selectTemplate()
const bool isCodeCompletion = (settingsButton == &m_generalSettings.ccSelectTemplate);
const bool isPreset1 = (settingsButton == &m_generalSettings.ccPreset1SelectTemplate);
const QString providerName = isCodeCompletion ? m_generalSettings.ccProvider.volatileValue()
: isPreset1 ? m_generalSettings.ccPreset1Provider.volatileValue()
: m_generalSettings.caProvider.volatileValue();
auto providerID = m_providersManager.getProviderByName(providerName)->providerID();
const auto templateList = isCodeCompletion || isPreset1 ? m_templateManger.fimTemplatesNames()
: m_templateManger.chatTemplatesNames();
const auto templateList = isCodeCompletion || isPreset1
? m_templateManger.getFimTemplatesForProvider(providerID)
: m_templateManger.getChatTemplatesForProvider(providerID);
auto &targetSettings = isCodeCompletion ? m_generalSettings.ccTemplate
: isPreset1 ? m_generalSettings.ccPreset1Template

View File

@ -10,6 +10,7 @@ add_library(LLMCore STATIC
OllamaMessage.hpp OllamaMessage.cpp
OpenAIMessage.hpp OpenAIMessage.cpp
ValidationUtils.hpp ValidationUtils.cpp
ProviderID.hpp
)
target_link_libraries(LLMCore

View File

@ -24,6 +24,7 @@
#include <QString>
#include "ContextData.hpp"
#include "ProviderID.hpp"
namespace QodeAssist::LLMCore {
@ -35,9 +36,9 @@ public:
virtual ~PromptTemplate() = default;
virtual TemplateType type() const = 0;
virtual QString name() const = 0;
// virtual QString promptTemplate() const = 0;
virtual QStringList stopWords() const = 0;
virtual void prepareRequest(QJsonObject &request, const ContextData &context) const = 0;
virtual QString description() const = 0;
virtual bool isSupportProvider(ProviderID id) const = 0;
};
} // namespace QodeAssist::LLMCore

View File

@ -37,6 +37,32 @@ QStringList PromptTemplateManager::chatTemplatesNames() const
return m_chatTemplates.keys();
}
QStringList PromptTemplateManager::getFimTemplatesForProvider(ProviderID id)
{
QStringList templateList;
for (const auto tmpl : m_fimTemplates) {
if (tmpl->isSupportProvider(id)) {
templateList.append(tmpl->name());
}
}
return templateList;
}
QStringList PromptTemplateManager::getChatTemplatesForProvider(ProviderID id)
{
QStringList templateList;
for (const auto tmpl : m_chatTemplates) {
if (tmpl->isSupportProvider(id)) {
templateList.append(tmpl->name());
}
}
return templateList;
}
PromptTemplateManager::~PromptTemplateManager()
{
qDeleteAll(m_fimTemplates);
@ -44,11 +70,15 @@ PromptTemplateManager::~PromptTemplateManager()
PromptTemplate *PromptTemplateManager::getFimTemplateByName(const QString &templateName)
{
if (!m_fimTemplates.contains(templateName))
return m_fimTemplates.first();
return m_fimTemplates[templateName];
}
PromptTemplate *PromptTemplateManager::getChatTemplateByName(const QString &templateName)
{
if (!m_chatTemplates.contains(templateName))
return m_chatTemplates.first();
return m_chatTemplates[templateName];
}

View File

@ -51,6 +51,9 @@ public:
QStringList fimTemplatesNames() const;
QStringList chatTemplatesNames() const;
QStringList getFimTemplatesForProvider(ProviderID id);
QStringList getChatTemplatesForProvider(ProviderID id);
private:
PromptTemplateManager() = default;
PromptTemplateManager(const PromptTemplateManager &) = delete;

View File

@ -54,6 +54,7 @@ public:
virtual QList<QString> validateRequest(const QJsonObject &request, TemplateType type) = 0;
virtual QString apiKey() const = 0;
virtual void prepareNetworkRequest(QNetworkRequest &networkRequest) const = 0;
virtual ProviderID providerID() const = 0;
};
} // namespace QodeAssist::LLMCore

33
llmcore/ProviderID.hpp Normal file
View File

@ -0,0 +1,33 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
namespace QodeAssist::LLMCore {
enum class ProviderID {
Any,
Ollama,
LMStudio,
Claude,
OpenAI,
OpenAICompatible,
MistralAI,
OpenRouter
};
}

View File

@ -39,6 +39,8 @@ ProvidersManager::~ProvidersManager()
Provider *ProvidersManager::getProviderByName(const QString &providerName)
{
if (!m_providers.contains(providerName))
return m_providers.first();
return m_providers[providerName];
}

View File

@ -65,9 +65,9 @@ void ClaudeProvider::prepareRequest(
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
if (!prompt->isSupportProvider(providerID())) {
LOG_MESSAGE(QString("Template %1 doesn't support %2 provider").arg(name(), prompt->name()));
}
prompt->prepareRequest(request, context);
@ -213,4 +213,9 @@ void ClaudeProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) cons
}
}
LLMCore::ProviderID ClaudeProvider::providerID() const
{
return LLMCore::ProviderID::Claude;
}
} // namespace QodeAssist::Providers

View File

@ -41,6 +41,7 @@ public:
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
QString apiKey() const override;
void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
LLMCore::ProviderID providerID() const override;
};
} // namespace QodeAssist::Providers

View File

@ -168,15 +168,20 @@ void LMStudioProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) co
networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
}
LLMCore::ProviderID LMStudioProvider::providerID() const
{
return LLMCore::ProviderID::LMStudio;
}
void QodeAssist::Providers::LMStudioProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
if (!prompt->isSupportProvider(providerID())) {
LOG_MESSAGE(QString("Template %1 doesn't support %2 provider").arg(name(), prompt->name()));
}
prompt->prepareRequest(request, context);

View File

@ -41,6 +41,7 @@ public:
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
QString apiKey() const override;
void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
LLMCore::ProviderID providerID() const override;
};
} // namespace QodeAssist::Providers

View File

@ -171,15 +171,20 @@ void MistralAIProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) c
}
}
LLMCore::ProviderID MistralAIProvider::providerID() const
{
return LLMCore::ProviderID::MistralAI;
}
void MistralAIProvider::prepareRequest(
QJsonObject &request,
LLMCore::PromptTemplate *prompt,
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
if (!prompt->isSupportProvider(providerID())) {
LOG_MESSAGE(QString("Template %1 doesn't support %2 provider").arg(name(), prompt->name()));
}
prompt->prepareRequest(request, context);

View File

@ -41,6 +41,7 @@ public:
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
QString apiKey() const override;
void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
LLMCore::ProviderID providerID() const override;
};
} // namespace QodeAssist::Providers

View File

@ -64,9 +64,9 @@ void OllamaProvider::prepareRequest(
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
if (!prompt->isSupportProvider(providerID())) {
LOG_MESSAGE(QString("Template %1 doesn't support %2 provider").arg(name(), prompt->name()));
}
prompt->prepareRequest(request, context);
@ -213,4 +213,9 @@ void OllamaProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) cons
networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
}
LLMCore::ProviderID OllamaProvider::providerID() const
{
return LLMCore::ProviderID::Ollama;
}
} // namespace QodeAssist::Providers

View File

@ -41,6 +41,7 @@ public:
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
QString apiKey() const override;
void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
LLMCore::ProviderID providerID() const override;
};
} // namespace QodeAssist::Providers

View File

@ -65,9 +65,9 @@ void OpenAICompatProvider::prepareRequest(
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
if (!prompt->isSupportProvider(providerID())) {
LOG_MESSAGE(QString("Template %1 doesn't support %2 provider").arg(name(), prompt->name()));
}
prompt->prepareRequest(request, context);
@ -180,4 +180,9 @@ void OpenAICompatProvider::prepareNetworkRequest(QNetworkRequest &networkRequest
}
}
LLMCore::ProviderID OpenAICompatProvider::providerID() const
{
return LLMCore::ProviderID::OpenAICompatible;
}
} // namespace QodeAssist::Providers

View File

@ -41,6 +41,7 @@ public:
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
QString apiKey() const override;
void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
LLMCore::ProviderID providerID() const override;
};
} // namespace QodeAssist::Providers

View File

@ -66,9 +66,9 @@ void OpenAIProvider::prepareRequest(
LLMCore::ContextData context,
LLMCore::RequestType type)
{
// if (!isSupportedTemplate(prompt->name())) {
// LOG_MESSAGE(QString("Provider doesn't support %1 template").arg(prompt->name()));
// }
if (!prompt->isSupportProvider(providerID())) {
LOG_MESSAGE(QString("Template %1 doesn't support %2 provider").arg(name(), prompt->name()));
}
prompt->prepareRequest(request, context);
@ -216,4 +216,9 @@ void OpenAIProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) cons
}
}
LLMCore::ProviderID OpenAIProvider::providerID() const
{
return LLMCore::ProviderID::OpenAI;
}
} // namespace QodeAssist::Providers

View File

@ -41,6 +41,7 @@ public:
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
QString apiKey() const override;
void prepareNetworkRequest(QNetworkRequest &networkRequest) const override;
LLMCore::ProviderID providerID() const override;
};
} // namespace QodeAssist::Providers

View File

@ -19,8 +19,6 @@
#include "OpenRouterAIProvider.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
#include "settings/ProviderSettings.hpp"
#include <QJsonArray>
@ -99,4 +97,9 @@ QString OpenRouterProvider::apiKey() const
return Settings::providerSettings().openRouterApiKey();
}
LLMCore::ProviderID OpenRouterProvider::providerID() const
{
return LLMCore::ProviderID::OpenRouter;
}
} // namespace QodeAssist::Providers

View File

@ -31,6 +31,7 @@ public:
QString url() const override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QString apiKey() const override;
LLMCore::ProviderID providerID() const override;
};
} // namespace QodeAssist::Providers

View File

@ -61,6 +61,19 @@ public:
{
return "The message will contain the following tokens: ### Instruction:\n### Response:\n";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
case QodeAssist::LLMCore::ProviderID::LMStudio:
case QodeAssist::LLMCore::ProviderID::OpenRouter:
case QodeAssist::LLMCore::ProviderID::OpenAICompatible:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -60,6 +60,18 @@ public:
{
return "The message will contain the following tokens: <|im_start|>%1\n%2\n<|im_end|>";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
case QodeAssist::LLMCore::ProviderID::LMStudio:
case QodeAssist::LLMCore::ProviderID::OpenRouter:
case QodeAssist::LLMCore::ProviderID::OpenAICompatible:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -50,6 +50,15 @@ public:
request["messages"] = messages;
}
QString description() const override { return "Claude"; }
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Claude:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -42,6 +42,15 @@ public:
{
return "The message will contain the following tokens: <PRE> %1 <SUF>%2 <MID>";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -43,6 +43,15 @@ public:
{
return "The message will contain the following tokens: <SUF>%1<PRE>%2<MID>";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -59,6 +59,18 @@ public:
{
return "The message will contain the following tokens: [INST]%1[/INST]\n";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
case QodeAssist::LLMCore::ProviderID::LMStudio:
case QodeAssist::LLMCore::ProviderID::OpenRouter:
case QodeAssist::LLMCore::ProviderID::OpenAICompatible:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -63,6 +63,18 @@ public:
return "The message will contain the following tokens: "
"<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
case QodeAssist::LLMCore::ProviderID::LMStudio:
case QodeAssist::LLMCore::ProviderID::OpenRouter:
case QodeAssist::LLMCore::ProviderID::OpenAICompatible:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -37,6 +37,15 @@ public:
request["suffix"] = context.suffix.value_or("");
}
QString description() const override { return "template will take from ollama modelfile"; }
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::MistralAI:
return true;
default:
return false;
}
}
};
class MistralAIChat : public LLMCore::PromptTemplate
@ -64,6 +73,15 @@ public:
request["messages"] = messages;
}
QString description() const override { return "template will take from ollama modelfile"; }
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::MistralAI:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -38,6 +38,15 @@ public:
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override { return "template will take from ollama modelfile"; }
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
return true;
default:
return false;
}
}
};
class OllamaChat : public LLMCore::PromptTemplate
@ -65,6 +74,15 @@ public:
request["messages"] = messages;
}
QString description() const override { return "template will take from ollama modelfile"; }
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -49,6 +49,15 @@ public:
request["messages"] = messages;
}
QString description() const override { return "OpenAI"; }
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::OpenAI:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -49,6 +49,17 @@ public:
request["messages"] = messages;
}
QString description() const override { return "chat without tokens"; }
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::OpenAICompatible:
case QodeAssist::LLMCore::ProviderID::OpenRouter:
case QodeAssist::LLMCore::ProviderID::LMStudio:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -41,6 +41,15 @@ public:
return "The message will contain the following tokens: "
"<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -44,6 +44,15 @@ public:
return "The message will contain the following tokens: "
"<fim_prefix>%1<fim_suffix>%2<fim_middle>";
}
bool isSupportProvider(LLMCore::ProviderID id) const override
{
switch (id) {
case QodeAssist::LLMCore::ProviderID::Ollama:
return true;
default:
return false;
}
}
};
} // namespace QodeAssist::Templates

View File

@ -41,9 +41,9 @@ namespace QodeAssist::Templates {
inline void registerTemplates()
{
auto &templateManager = LLMCore::PromptTemplateManager::instance();
templateManager.registerTemplate<CodeLlamaFim>();
templateManager.registerTemplate<OllamaFim>();
templateManager.registerTemplate<OllamaChat>();
templateManager.registerTemplate<OllamaFim>();
templateManager.registerTemplate<CodeLlamaFim>();
templateManager.registerTemplate<Claude>();
templateManager.registerTemplate<OpenAI>();
templateManager.registerTemplate<MistralAIFim>();