refactor: Rework providers and templates logic

This commit is contained in:
Petr Mironychev
2025-02-22 19:39:28 +01:00
committed by GitHub
parent bd25736a55
commit d96f44d42c
44 changed files with 701 additions and 524 deletions

View File

@ -29,33 +29,32 @@ class Alpaca : public LLMCore::PromptTemplate
public:
QString name() const override { return "Alpaca"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override
{
return QStringList() << "### Instruction:" << "### Response:";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
QString fullContent;
QString formattedContent;
if (role == "system") {
formattedContent = content + "\n\n";
} else if (role == "user") {
formattedContent = "### Instruction:\n" + content + "\n\n";
} else if (role == "assistant") {
formattedContent = "### Response:\n" + content + "\n\n";
}
message["content"] = formattedContent;
messages[i] = message;
if (context.systemPrompt) {
fullContent += context.systemPrompt.value() + "\n\n";
}
if (context.history) {
for (const auto &msg : context.history.value()) {
if (msg.role == "user") {
fullContent += QString("### Instruction:\n%1\n\n").arg(msg.content);
} else if (msg.role == "assistant") {
fullContent += QString("### Response:\n%1\n\n").arg(msg.content);
}
}
}
messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});
request["messages"] = messages;
}
QString description() const override

View File

@ -30,23 +30,28 @@ class ChatML : public LLMCore::PromptTemplate
public:
QString name() const override { return "ChatML"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override
{
return QStringList() << "<|im_start|>" << "<|im_end|>";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
if (context.systemPrompt) {
messages.append(QJsonObject{
{"role", "system"},
{"content",
QString("<|im_start|>system\n%2\n<|im_end|>").arg(context.systemPrompt.value())}});
}
message["content"] = QString("<|im_start|>%1\n%2\n<|im_end|>").arg(role, content);
messages[i] = message;
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{
{"role", msg.role},
{"content",
QString("<|im_start|>%1\n%2\n<|im_end|>").arg(msg.role, msg.content)}});
}
}
request["messages"] = messages;

View File

@ -30,9 +30,25 @@ class Claude : public LLMCore::PromptTemplate
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Claude"; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
request["system"] = context.systemPrompt.value();
}
if (context.history) {
for (const auto &msg : context.history.value()) {
if (msg.role != "system") {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
}
request["messages"] = messages;
}
QString description() const override { return "Claude"; }
};

View File

@ -26,17 +26,17 @@ namespace QodeAssist::Templates {
class CodeLlamaFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "CodeLlama FIM"; }
QString promptTemplate() const override { return "<PRE> %1 <SUF>%2 <MID>"; }
QStringList stopWords() const override
{
return QStringList() << "<EOT>" << "<PRE>" << "<SUF" << "<MID>";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<PRE> %1 <SUF>%2 <MID>")
.arg(context.prefix.value_or(""), context.suffix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
class CodeLlamaQMLFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "CodeLlama QML FIM"; }
QString promptTemplate() const override { return "<SUF>%1<PRE>%2<MID>"; }
QStringList stopWords() const override
{
return QStringList() << "<SUF>" << "<PRE>" << "</PRE>" << "</SUF>" << "< EOT >" << "\\end"
@ -36,8 +35,9 @@ public:
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.suffix, context.prefix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<SUF>%1<PRE>%2<MID>")
.arg(context.suffix.value_or(""), context.prefix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -29,30 +29,30 @@ class Llama2 : public LLMCore::PromptTemplate
public:
QString name() const override { return "Llama 2"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList() << "[INST]"; }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
QString fullContent;
QString formattedContent;
if (role == "system") {
formattedContent = QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(content);
} else if (role == "user") {
formattedContent = QString("[INST]%1[/INST]\n").arg(content);
} else if (role == "assistant") {
formattedContent = content + "\n";
}
message["content"] = formattedContent;
messages[i] = message;
if (context.systemPrompt) {
fullContent
+= QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(context.systemPrompt.value());
}
if (context.history) {
for (const auto &msg : context.history.value()) {
if (msg.role == "user") {
fullContent += QString("[INST]%1[/INST]\n").arg(msg.content);
} else if (msg.role == "assistant") {
fullContent += msg.content + "\n";
}
}
}
messages.append(QJsonObject{{"role", "user"}, {"content", fullContent}});
request["messages"] = messages;
}
QString description() const override

View File

@ -30,24 +30,30 @@ class Llama3 : public LLMCore::PromptTemplate
public:
QString name() const override { return "Llama 3"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString promptTemplate() const override { return ""; }
QStringList stopWords() const override
{
return QStringList() << "<|start_header_id|>" << "<|end_header_id|>" << "<|eot_id|>";
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages = request["messages"].toArray();
QJsonArray messages;
for (int i = 0; i < messages.size(); ++i) {
QJsonObject message = messages[i].toObject();
QString role = message["role"].toString();
QString content = message["content"].toString();
if (context.systemPrompt) {
messages.append(QJsonObject{
{"role", "system"},
{"content",
QString("<|start_header_id|>system<|end_header_id|>%2<|eot_id|>")
.arg(context.systemPrompt.value())}});
}
message["content"]
= QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>").arg(role, content);
messages[i] = message;
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{
{"role", msg.role},
{"content",
QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>")
.arg(msg.role, msg.content)}});
}
}
request["messages"] = messages;

69
templates/MistralAI.hpp Normal file
View File

@ -0,0 +1,69 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <QJsonArray>
#include "llmcore/PromptTemplate.hpp"
namespace QodeAssist::Templates {
class MistralAIFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "Mistral AI FIM"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
request["prompt"] = context.prefix.value_or("");
request["suffix"] = context.suffix.value_or("");
}
QString description() const override { return "template will take from ollama modelfile"; }
};
class MistralAIChat : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Mistral AI Chat"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "template will take from ollama modelfile"; }
};
} // namespace QodeAssist::Templates

View File

@ -25,31 +25,44 @@
namespace QodeAssist::Templates {
class OllamaAutoFim : public LLMCore::PromptTemplate
class OllamaFim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
QString name() const override { return "Ollama Auto FIM"; }
QString promptTemplate() const override { return {}; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "Ollama FIM"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
request["prompt"] = context.prefix;
request["suffix"] = context.suffix;
request["prompt"] = context.prefix.value_or("");
request["suffix"] = context.suffix.value_or("");
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override { return "template will take from ollama modelfile"; }
};
class OllamaAutoChat : public LLMCore::PromptTemplate
class OllamaChat : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Ollama Auto Chat"; }
QString promptTemplate() const override { return {}; }
QString name() const override { return "Ollama Chat"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "template will take from ollama modelfile"; }
};

View File

@ -30,9 +30,24 @@ class OpenAI : public LLMCore::PromptTemplate
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "OpenAI"; }
QString promptTemplate() const override { return {}; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "OpenAI"; }
};

View File

@ -25,15 +25,29 @@
namespace QodeAssist::Templates {
class BasicChat : public LLMCore::PromptTemplate
class OpenAICompatible : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
QString name() const override { return "Basic Chat"; }
QString promptTemplate() const override { return {}; }
QString name() const override { return "OpenAI Compatible"; }
QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{}
{
QJsonArray messages;
if (context.systemPrompt) {
messages.append(
QJsonObject{{"role", "system"}, {"content", context.systemPrompt.value()}});
}
if (context.history) {
for (const auto &msg : context.history.value()) {
messages.append(QJsonObject{{"role", msg.role}, {"content", msg.content}});
}
}
request["messages"] = messages;
}
QString description() const override { return "chat without tokens"; }
};

View File

@ -28,16 +28,13 @@ class QwenFim : public LLMCore::PromptTemplate
{
public:
QString name() const override { return "Qwen FIM"; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
QString promptTemplate() const override
{
return "<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>";
}
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QStringList stopWords() const override { return QStringList() << "<|endoftext|>" << "<|EOT|>"; }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>")
.arg(context.prefix.value_or(""), context.suffix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -26,9 +26,8 @@ namespace QodeAssist::Templates {
class StarCoder2Fim : public LLMCore::PromptTemplate
{
public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::FIM; }
QString name() const override { return "StarCoder2 FIM"; }
QString promptTemplate() const override { return "<fim_prefix>%1<fim_suffix>%2<fim_middle>"; }
QStringList stopWords() const override
{
return QStringList() << "<|endoftext|>" << "<file_sep>" << "<fim_prefix>" << "<fim_suffix>"
@ -36,8 +35,9 @@ public:
}
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{
QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
request["prompt"] = formattedPrompt;
request["prompt"] = QString("<fim_prefix>%1<fim_suffix>%2<fim_middle>")
.arg(context.prefix.value_or(""), context.suffix.value_or(""));
request["system"] = context.systemPrompt.value_or("");
}
QString description() const override
{

View File

@ -21,17 +21,18 @@
#include "llmcore/PromptTemplateManager.hpp"
#include "templates/Alpaca.hpp"
#include "templates/BasicChat.hpp"
#include "templates/ChatML.hpp"
#include "templates/Claude.hpp"
#include "templates/CodeLlamaFim.hpp"
#include "templates/CodeLlamaQMLFim.hpp"
#include "templates/CustomFimTemplate.hpp"
#include "templates/DeepSeekCoderFim.hpp"
#include "templates/Llama2.hpp"
#include "templates/Llama3.hpp"
#include "templates/MistralAI.hpp"
#include "templates/Ollama.hpp"
#include "templates/OpenAI.hpp"
#include "templates/OpenAICompatible.hpp"
// #include "templates/CustomFimTemplate.hpp"
// #include "templates/DeepSeekCoderFim.hpp"
#include "templates/Llama2.hpp"
#include "templates/Llama3.hpp"
#include "templates/Qwen.hpp"
#include "templates/StarCoder2Fim.hpp"
@ -41,20 +42,22 @@ inline void registerTemplates()
{
auto &templateManager = LLMCore::PromptTemplateManager::instance();
templateManager.registerTemplate<CodeLlamaFim>();
templateManager.registerTemplate<StarCoder2Fim>();
templateManager.registerTemplate<DeepSeekCoderFim>();
templateManager.registerTemplate<CustomTemplate>();
templateManager.registerTemplate<QwenFim>();
templateManager.registerTemplate<OllamaAutoFim>();
templateManager.registerTemplate<OllamaAutoChat>();
templateManager.registerTemplate<BasicChat>();
templateManager.registerTemplate<Llama3>();
templateManager.registerTemplate<ChatML>();
templateManager.registerTemplate<Alpaca>();
templateManager.registerTemplate<Llama2>();
templateManager.registerTemplate<OllamaFim>();
templateManager.registerTemplate<OllamaChat>();
templateManager.registerTemplate<Claude>();
templateManager.registerTemplate<OpenAI>();
templateManager.registerTemplate<MistralAIFim>();
templateManager.registerTemplate<MistralAIChat>();
templateManager.registerTemplate<CodeLlamaQMLFim>();
templateManager.registerTemplate<ChatML>();
templateManager.registerTemplate<Llama2>();
templateManager.registerTemplate<Llama3>();
templateManager.registerTemplate<StarCoder2Fim>();
// templateManager.registerTemplate<DeepSeekCoderFim>();
// templateManager.registerTemplate<CustomTemplate>();
templateManager.registerTemplate<QwenFim>();
templateManager.registerTemplate<OpenAICompatible>();
templateManager.registerTemplate<Alpaca>();
}
} // namespace QodeAssist::Templates