From 1261f913bbd7e9db49ce3c11cb9adefb68a642a5 Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Tue, 26 Nov 2024 00:28:27 +0100 Subject: [PATCH] :recycle: refactor: Rework currents and add new templates Add Alpaca, Llama3, LLama2, ChatML templates --- CMakeLists.txt | 12 +++- ChatView/ClientInterface.cpp | 3 +- llmcore/PromptTemplate.hpp | 1 + providers/Providers.hpp | 37 ++++++++++ qodeassist.cpp | 38 ++--------- templates/Alpaca.hpp | 67 +++++++++++++++++++ templates/BasicChat.hpp | 13 +--- .../{DeepSeekCoderChat.hpp => ChatML.hpp} | 30 +++++---- templates/CodeLlamaFim.hpp | 5 +- templates/CustomFimTemplate.hpp | 2 +- templates/DeepSeekCoderFim.hpp | 5 ++ templates/{CodeLlamaChat.hpp => Llama2.hpp} | 44 +++++++----- templates/{StarCoderChat.hpp => Llama3.hpp} | 31 ++++++--- templates/Ollama.hpp | 3 +- templates/Qwen.hpp | 32 ++------- templates/StarCoder2Fim.hpp | 5 ++ templates/Templates.hpp | 54 +++++++++++++++ 17 files changed, 263 insertions(+), 119 deletions(-) create mode 100644 providers/Providers.hpp create mode 100644 templates/Alpaca.hpp rename templates/{DeepSeekCoderChat.hpp => ChatML.hpp} (62%) rename templates/{CodeLlamaChat.hpp => Llama2.hpp} (54%) rename templates/{StarCoderChat.hpp => Llama3.hpp} (59%) create mode 100644 templates/Templates.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 45a6620..f09abd2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,16 +40,22 @@ add_qtc_plugin(QodeAssist QodeAssistConstants.hpp QodeAssisttr.h LLMClientInterface.hpp LLMClientInterface.cpp + templates/Templates.hpp templates/CodeLlamaFim.hpp templates/StarCoder2Fim.hpp templates/DeepSeekCoderFim.hpp templates/CustomFimTemplate.hpp - templates/DeepSeekCoderChat.hpp - templates/CodeLlamaChat.hpp + + templates/Qwen.hpp - templates/StarCoderChat.hpp + templates/Ollama.hpp templates/BasicChat.hpp + templates/Llama3.hpp + templates/ChatML.hpp + templates/Alpaca.hpp + templates/Llama2.hpp + 
providers/Providers.hpp providers/OllamaProvider.hpp providers/OllamaProvider.cpp providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp diff --git a/ChatView/ClientInterface.cpp b/ChatView/ClientInterface.cpp index 8b68819..3ec44b8 100644 --- a/ChatView/ClientInterface.cpp +++ b/ChatView/ClientInterface.cpp @@ -68,6 +68,8 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil { cancelRequest(); + m_chatModel->addMessage(message, ChatModel::ChatRole::User, ""); + auto &chatAssistantSettings = Settings::chatAssistantSettings(); auto providerName = Settings::generalSettings().caProvider(); @@ -128,7 +130,6 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil QJsonObject request; request["id"] = QUuid::createUuid().toString(); - m_chatModel->addMessage(message, ChatModel::ChatRole::User, ""); m_requestHandler->sendLLMRequest(config, request); } diff --git a/llmcore/PromptTemplate.hpp b/llmcore/PromptTemplate.hpp index 3952810..5d7ff2b 100644 --- a/llmcore/PromptTemplate.hpp +++ b/llmcore/PromptTemplate.hpp @@ -38,5 +38,6 @@ public: virtual QString promptTemplate() const = 0; virtual QStringList stopWords() const = 0; virtual void prepareRequest(QJsonObject &request, const ContextData &context) const = 0; + virtual QString description() const = 0; }; } // namespace QodeAssist::LLMCore diff --git a/providers/Providers.hpp b/providers/Providers.hpp new file mode 100644 index 0000000..7bb614f --- /dev/null +++ b/providers/Providers.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include "llmcore/ProvidersManager.hpp" +#include "providers/LMStudioProvider.hpp" +#include "providers/OllamaProvider.hpp" +#include "providers/OpenAICompatProvider.hpp" + +namespace QodeAssist::Providers { + +inline void registerProviders() +{ + auto &providerManager = LLMCore::ProvidersManager::instance(); + providerManager.registerProvider(); + providerManager.registerProvider(); + providerManager.registerProvider(); +} + +} // namespace QodeAssist::Providers diff --git a/qodeassist.cpp b/qodeassist.cpp index c132a68..9c8d5e8 100644 --- a/qodeassist.cpp +++ b/qodeassist.cpp @@ -43,22 +43,9 @@ #include "QodeAssistClient.hpp" #include "chat/ChatOutputPane.h" #include "chat/NavigationPanel.hpp" -#include "llmcore/PromptTemplateManager.hpp" -#include "llmcore/ProvidersManager.hpp" -#include "providers/LMStudioProvider.hpp" -#include "providers/OllamaProvider.hpp" -#include "providers/OpenAICompatProvider.hpp" -#include "templates/BasicChat.hpp" -#include "templates/CodeLlamaChat.hpp" -#include "templates/CodeLlamaFim.hpp" -#include "templates/CustomFimTemplate.hpp" -#include "templates/DeepSeekCoderChat.hpp" -#include "templates/DeepSeekCoderFim.hpp" -#include "templates/Ollama.hpp" -#include "templates/Qwen.hpp" -#include "templates/StarCoder2Fim.hpp" -#include "templates/StarCoderChat.hpp" +#include "providers/Providers.hpp" +#include "templates/Templates.hpp" using namespace Utils; using namespace Core; @@ -85,25 +72,8 @@ public: void initialize() final { - auto &providerManager = LLMCore::ProvidersManager::instance(); - providerManager.registerProvider(); - 
providerManager.registerProvider(); - providerManager.registerProvider(); - - auto &templateManager = LLMCore::PromptTemplateManager::instance(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); - templateManager.registerTemplate(); + Providers::registerProviders(); + Templates::registerTemplates(); Utils::Icon QCODEASSIST_ICON( {{":/resources/images/qoderassist-icon.png", Utils::Theme::IconsBaseColor}}); diff --git a/templates/Alpaca.hpp b/templates/Alpaca.hpp new file mode 100644 index 0000000..abc811b --- /dev/null +++ b/templates/Alpaca.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . 
+ */ + +#pragma once + +#include "llmcore/PromptTemplate.hpp" +#include + +namespace QodeAssist::Templates { + +class Alpaca : public LLMCore::PromptTemplate +{ +public: + QString name() const override { return "Alpaca"; } + LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; } + QString promptTemplate() const override { return {}; } + QStringList stopWords() const override + { + return QStringList() << "### Instruction:" << "### Response:"; + } + void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override + { + QJsonArray messages = request["messages"].toArray(); + + for (int i = 0; i < messages.size(); ++i) { + QJsonObject message = messages[i].toObject(); + QString role = message["role"].toString(); + QString content = message["content"].toString(); + + QString formattedContent; + if (role == "system") { + formattedContent = content + "\n\n"; + } else if (role == "user") { + formattedContent = "### Instruction:\n" + content + "\n\n"; + } else if (role == "assistant") { + formattedContent = "### Response:\n" + content + "\n\n"; + } + + message["content"] = formattedContent; + messages[i] = message; + } + + request["messages"] = messages; + } + QString description() const override + { + return "The message will contain the following tokens: ### Instruction:\n### Response:\n"; + } +}; + +} // namespace QodeAssist::Templates diff --git a/templates/BasicChat.hpp b/templates/BasicChat.hpp index 1bb6b31..7e552b8 100644 --- a/templates/BasicChat.hpp +++ b/templates/BasicChat.hpp @@ -32,18 +32,9 @@ public: QString name() const override { return "Basic Chat"; } QString promptTemplate() const override { return {}; } QStringList stopWords() const override { return QStringList(); } - void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override - { - QJsonArray messages = request["messages"].toArray(); - - QJsonObject newMessage; - newMessage["role"] = "user"; - 
newMessage["content"] = context.prefix; - messages.append(newMessage); - - request["messages"] = messages; - } + {} + QString description() const override { return "chat without tokens"; } }; } // namespace QodeAssist::Templates diff --git a/templates/DeepSeekCoderChat.hpp b/templates/ChatML.hpp similarity index 62% rename from templates/DeepSeekCoderChat.hpp rename to templates/ChatML.hpp index 44e8922..bb125cf 100644 --- a/templates/DeepSeekCoderChat.hpp +++ b/templates/ChatML.hpp @@ -20,35 +20,41 @@ #pragma once #include + #include "llmcore/PromptTemplate.hpp" namespace QodeAssist::Templates { -class DeepSeekCoderChat : public LLMCore::PromptTemplate +class ChatML : public LLMCore::PromptTemplate { public: - QString name() const override { return "DeepSeekCoder Chat"; } + QString name() const override { return "ChatML"; } LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; } - - QString promptTemplate() const override { return "### Instruction:\n%1\n### Response:\n"; } - + QString promptTemplate() const override { return {}; } QStringList stopWords() const override { - return QStringList() << "### Instruction:" << "### Response:" << "\n\n### " << "<|EOT|>"; + return QStringList() << "<|im_start|>" << "<|im_end|>"; } - void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override { - QString formattedPrompt = promptTemplate().arg(context.prefix); QJsonArray messages = request["messages"].toArray(); - QJsonObject newMessage; - newMessage["role"] = "user"; - newMessage["content"] = formattedPrompt; - messages.append(newMessage); + for (int i = 0; i < messages.size(); ++i) { + QJsonObject message = messages[i].toObject(); + QString role = message["role"].toString(); + QString content = message["content"].toString(); + + message["content"] = QString("<|im_start|>%1\n%2\n<|im_end|>").arg(role, content); + + messages[i] = message; + } request["messages"] = messages; } + QString description() const override + 
{ + return "The message will contain the following tokens: <|im_start|>%1\n%2\n<|im_end|>"; + } }; } // namespace QodeAssist::Templates diff --git a/templates/CodeLlamaFim.hpp b/templates/CodeLlamaFim.hpp index 7f6945f..36f2e2c 100644 --- a/templates/CodeLlamaFim.hpp +++ b/templates/CodeLlamaFim.hpp @@ -33,12 +33,15 @@ public: { return QStringList() << "" << "
" << "";
     }
-
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
         QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
         request["prompt"] = formattedPrompt;
     }
+    QString description() const override
+    {
+        return "The message will contain the following tokens: <PRE> %1 <SUF>%2 <MID>";
+    }
 };
 
 } // namespace QodeAssist::Templates
diff --git a/templates/CustomFimTemplate.hpp b/templates/CustomFimTemplate.hpp
index 9180c32..5b5fd83 100644
--- a/templates/CustomFimTemplate.hpp
+++ b/templates/CustomFimTemplate.hpp
@@ -39,7 +39,6 @@ public:
         return Settings::customPromptSettings().customJsonTemplate();
     }
     QStringList stopWords() const override { return QStringList(); }
-
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
         QJsonDocument doc = QJsonDocument::fromJson(promptTemplate().toUtf8());
@@ -56,6 +55,7 @@ public:
             request[it.key()] = it.value();
         }
     }
+    QString description() const override { return promptTemplate(); }
 
 private:
     QJsonValue processJsonValue(const QJsonValue &value, const LLMCore::ContextData &context) const
diff --git a/templates/DeepSeekCoderFim.hpp b/templates/DeepSeekCoderFim.hpp
index d7aa98f..8bfdd0a 100644
--- a/templates/DeepSeekCoderFim.hpp
+++ b/templates/DeepSeekCoderFim.hpp
@@ -38,6 +38,11 @@ public:
         QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
         request["prompt"] = formattedPrompt;
     }
+    QString description() const override
+    {
+        return "The message will contain the following tokens: "
+               "<|fim▁begin|>%1<|fim▁hole|>%2<|fim▁end|>";
+    }
 };
 
 } // namespace QodeAssist::Templates
diff --git a/templates/CodeLlamaChat.hpp b/templates/Llama2.hpp
similarity index 54%
rename from templates/CodeLlamaChat.hpp
rename to templates/Llama2.hpp
index 0a6b713..6888f7b 100644
--- a/templates/CodeLlamaChat.hpp
+++ b/templates/Llama2.hpp
@@ -19,38 +19,46 @@
 
 #pragma once
 
-#include 
-
 #include "llmcore/PromptTemplate.hpp"
+#include 
 
 namespace QodeAssist::Templates {
 
-class CodeLlamaChat : public LLMCore::PromptTemplate
+class Llama2 : public LLMCore::PromptTemplate
 {
 public:
+    QString name() const override { return "Llama 2"; }
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-    QString name() const override { return "CodeLlama Chat"; }
-    QString promptTemplate() const override { return "[INST] %1 [/INST]"; }
-    QStringList stopWords() const override { return QStringList() << "[INST]" << "[/INST]"; }
-
+    QString promptTemplate() const override { return {}; }
+    QStringList stopWords() const override { return QStringList() << "[INST]"; }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QString formattedPrompt = promptTemplate().arg(context.prefix);
         QJsonArray messages = request["messages"].toArray();
 
-        QJsonObject newMessage;
-        newMessage["role"] = "user";
-        newMessage["content"] = formattedPrompt;
-        messages.append(newMessage);
+        for (int i = 0; i < messages.size(); ++i) {
+            QJsonObject message = messages[i].toObject();
+            QString role = message["role"].toString();
+            QString content = message["content"].toString();
+
+            QString formattedContent;
+            if (role == "system") {
+                formattedContent = QString("[INST]<<SYS>>\n%1\n<</SYS>>[/INST]\n").arg(content);
+            } else if (role == "user") {
+                formattedContent = QString("[INST]%1[/INST]\n").arg(content);
+            } else if (role == "assistant") {
+                formattedContent = content + "\n";
+            }
+
+            message["content"] = formattedContent;
+            messages[i] = message;
+        }
 
         request["messages"] = messages;
     }
-};
-
-class LlamaChat : public CodeLlamaChat
-{
-public:
-    QString name() const override { return "Llama Chat"; }
+    QString description() const override
+    {
+        return "The message will contain the following tokens: [INST]%1[/INST]\n";
+    }
 };
 
 } // namespace QodeAssist::Templates
diff --git a/templates/StarCoderChat.hpp b/templates/Llama3.hpp
similarity index 59%
rename from templates/StarCoderChat.hpp
rename to templates/Llama3.hpp
index 437f53e..7b3bf2b 100644
--- a/templates/StarCoderChat.hpp
+++ b/templates/Llama3.hpp
@@ -20,32 +20,43 @@
 #pragma once
 
 #include 
+
 #include "llmcore/PromptTemplate.hpp"
 
 namespace QodeAssist::Templates {
 
-class StarCoderChat : public LLMCore::PromptTemplate
+class Llama3 : public LLMCore::PromptTemplate
 {
 public:
-    QString name() const override { return "StarCoder Chat"; }
+    QString name() const override { return "Llama 3"; }
     LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-    QString promptTemplate() const override { return "### Instruction:\n%1\n### Response:\n"; }
+    QString promptTemplate() const override { return ""; }
     QStringList stopWords() const override
     {
-        return QStringList() << "###"
-                             << "<|endoftext|>" << "";
+        return QStringList() << "<|start_header_id|>" << "<|end_header_id|>" << "<|eot_id|>";
     }
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
-        QString formattedPrompt = promptTemplate().arg(context.prefix);
         QJsonArray messages = request["messages"].toArray();
 
-        QJsonObject newMessage;
-        newMessage["role"] = "user";
-        newMessage["content"] = formattedPrompt;
-        messages.append(newMessage);
+        for (int i = 0; i < messages.size(); ++i) {
+            QJsonObject message = messages[i].toObject();
+            QString role = message["role"].toString();
+            QString content = message["content"].toString();
+
+            message["content"]
+                = QString("<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>").arg(role, content);
+
+            messages[i] = message;
+        }
 
         request["messages"] = messages;
     }
+    QString description() const override
+    {
+        return "The message will contain the following tokens: "
+               "<|start_header_id|>%1<|end_header_id|>%2<|eot_id|>";
+    }
 };
+
 } // namespace QodeAssist::Templates
diff --git a/templates/Ollama.hpp b/templates/Ollama.hpp
index e215c66..b434805 100644
--- a/templates/Ollama.hpp
+++ b/templates/Ollama.hpp
@@ -32,12 +32,12 @@ public:
     QString name() const override { return "Ollama Auto FIM"; }
     QString promptTemplate() const override { return {}; }
     QStringList stopWords() const override { return QStringList(); }
-
     void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
     {
         request["prompt"] = context.prefix;
         request["suffix"] = context.suffix;
     }
+    QString description() const override { return "template will take from ollama modelfile"; }
 };
 
 class OllamaAutoChat : public LLMCore::PromptTemplate
@@ -59,6 +59,7 @@ public:
 
         request["messages"] = messages;
     }
+    QString description() const override { return "template will take from ollama modelfile"; }
 };
 
 } // namespace QodeAssist::Templates
diff --git a/templates/Qwen.hpp b/templates/Qwen.hpp
index cf36629..1231a74 100644
--- a/templates/Qwen.hpp
+++ b/templates/Qwen.hpp
@@ -24,33 +24,6 @@
 
 namespace QodeAssist::Templates {
 
-class QwenChat : public LLMCore::PromptTemplate
-{
-public:
-    QString name() const override { return "Qwen Chat"; }
-    LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; }
-
-    QString promptTemplate() const override { return "### Instruction:\n%1\n### Response:\n"; }
-
-    QStringList stopWords() const override
-    {
-        return QStringList() << "### Instruction:" << "### Response:" << "\n\n### " << "<|EOT|>";
-    }
-
-    void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
-    {
-        QString formattedPrompt = promptTemplate().arg(context.prefix);
-        QJsonArray messages = request["messages"].toArray();
-
-        QJsonObject newMessage;
-        newMessage["role"] = "user";
-        newMessage["content"] = formattedPrompt;
-        messages.append(newMessage);
-
-        request["messages"] = messages;
-    }
-};
-
 class QwenFim : public LLMCore::PromptTemplate
 {
 public:
@@ -66,6 +39,11 @@ public:
         QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
         request["prompt"] = formattedPrompt;
     }
+    QString description() const override
+    {
+        return "The message will contain the following tokens: "
+               "<|fim_prefix|>%1<|fim_suffix|>%2<|fim_middle|>";
+    }
 };
 
 } // namespace QodeAssist::Templates
diff --git a/templates/StarCoder2Fim.hpp b/templates/StarCoder2Fim.hpp
index e0d3ca2..00f68da 100644
--- a/templates/StarCoder2Fim.hpp
+++ b/templates/StarCoder2Fim.hpp
@@ -39,6 +39,11 @@ public:
         QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
         request["prompt"] = formattedPrompt;
     }
+    QString description() const override
+    {
+        return "The message will contain the following tokens: "
+               "%1%2";
+    }
 };
 
 } // namespace QodeAssist::Templates
diff --git a/templates/Templates.hpp b/templates/Templates.hpp
new file mode 100644
index 0000000..098aaef
--- /dev/null
+++ b/templates/Templates.hpp
@@ -0,0 +1,54 @@
+/* 
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include "llmcore/PromptTemplateManager.hpp"
+#include "templates/Alpaca.hpp"
+#include "templates/BasicChat.hpp"
+#include "templates/ChatML.hpp"
+#include "templates/CodeLlamaFim.hpp"
+#include "templates/CustomFimTemplate.hpp"
+#include "templates/DeepSeekCoderFim.hpp"
+#include "templates/Llama2.hpp"
+#include "templates/Llama3.hpp"
+#include "templates/Ollama.hpp"
+#include "templates/Qwen.hpp"
+#include "templates/StarCoder2Fim.hpp"
+
+namespace QodeAssist::Templates {
+
+inline void registerTemplates()
+{
+    auto &templateManager = LLMCore::PromptTemplateManager::instance();
+    templateManager.registerTemplate<CodeLlamaFim>();
+    templateManager.registerTemplate<StarCoder2Fim>();
+    templateManager.registerTemplate<DeepSeekCoderFim>();
+    templateManager.registerTemplate<CustomTemplate>();
+    templateManager.registerTemplate<QwenFim>();
+    templateManager.registerTemplate<OllamaAutoFim>();
+    templateManager.registerTemplate<OllamaAutoChat>();
+    templateManager.registerTemplate<BasicChat>();
+    templateManager.registerTemplate<Llama2>();
+    templateManager.registerTemplate<Llama3>();
+    templateManager.registerTemplate<ChatML>();
+    templateManager.registerTemplate<Alpaca>();
+}
+
+} // namespace QodeAssist::Templates