diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b5d121..42f8ec3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,7 @@ add_qtc_plugin(QodeAssist templates/StarCoder2Template.hpp templates/CodeQwenChat.hpp templates/DeepSeekCoderV2.hpp + templates/CustomTemplate.hpp providers/LLMProvider.hpp providers/OllamaProvider.hpp providers/OllamaProvider.cpp providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp diff --git a/QodeAssistConstants.hpp b/QodeAssistConstants.hpp index 4cb9971..57fb905 100644 --- a/QodeAssistConstants.hpp +++ b/QodeAssistConstants.hpp @@ -56,6 +56,7 @@ const char MULTILINE_COMPLETION[] = "QodeAssist.multilineCompletion"; const char API_KEY[] = "QodeAssist.apiKey"; const char USE_SPECIFIC_INSTRUCTIONS[] = "QodeAssist.useSpecificInstructions"; const char USE_FILE_PATH_IN_CONTEXT[] = "QodeAssist.useFilePathInContext"; +const char CUSTOM_JSON_TEMPLATE[] = "QodeAssist.customJsonTemplate"; const char QODE_ASSIST_GENERAL_OPTIONS_ID[] = "QodeAssist.GeneralOptions"; const char QODE_ASSIST_GENERAL_OPTIONS_CATEGORY[] = "QodeAssist.Category"; diff --git a/QodeAssistSettings.cpp b/QodeAssistSettings.cpp index 188441e..f2c1aab 100644 --- a/QodeAssistSettings.cpp +++ b/QodeAssistSettings.cpp @@ -176,6 +176,29 @@ QodeAssistSettings::QodeAssistSettings() apiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay); apiKey.setPlaceHolderText(Tr::tr("Enter your API key here")); + customJsonTemplate.setSettingsKey(Constants::CUSTOM_JSON_TEMPLATE); + customJsonTemplate.setLabelText("Custom JSON Template:"); + customJsonTemplate.setDisplayStyle(Utils::StringAspect::TextEditDisplay); + customJsonTemplate.setDefaultValue(R"({ + "prompt": "{{QODE_INSTRUCTIONS}}{{QODE_PREFIX}}{{QODE_SUFFIX}}", + "options": { + "temperature": 0.7, + "top_p": 0.95, + "top_k": 40, + "num_predict": 100, + "stop": [ + "<|endoftext|>", + "", + "", + "", + "" + ], + "frequency_penalty": 0, + "presence_penalty": 0 + }, + "stream": true +})"); + const auto &manager = LLMProvidersManager::instance(); if (!manager.getProviderNames().isEmpty()) { const auto providerNames = manager.getProviderNames(); @@ -203,6 +226,8 @@ QodeAssistSettings::QodeAssistSettings() specificInstractions.setEnabled(useSpecificInstructions()); PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.stringValue()); LLMProvidersManager::instance().setCurrentProvider(llmProviders.stringValue()); + customJsonTemplate.setVisible(PromptTemplateManager::instance().getCurrentTemplate()->name() + == "Custom Template"); setLoggingEnabled(enableLogging()); @@ -221,6 +246,7 @@ QodeAssistSettings::QodeAssistSettings() Form{Column{Row{selectModels, modelName}}}}, Group{title(Tr::tr("FIM Prompt Settings")), Form{Column{fimPrompts, + Row{customJsonTemplate, Space{40}}, readFullFile, maxFileThreshold, readStringsBeforeCursor, @@ -256,6 +282,7 @@ void QodeAssistSettings::setupConnections() int index = fimPrompts.volatileValue(); logMessage(QString("currentPrompt %1").arg(fimPrompts.displayForIndex(index))); PromptTemplateManager::instance().setCurrentTemplate(fimPrompts.displayForIndex(index)); + customJsonTemplate.setVisible(fimPrompts.displayForIndex(index) == "Custom Template"); }); connect(&selectModels, &ButtonAspect::clicked, this, [this]() { showModelSelectionDialog(); }); @@ -366,6 +393,7 @@ void QodeAssistSettings::resetSettingsToDefaults() resetAspect(multiLineCompletion); resetAspect(useFilePathInContext); resetAspect(useSpecificInstructions); + resetAspect(customJsonTemplate); fimPrompts.setStringValue("StarCoder2"); llmProviders.setStringValue("Ollama"); diff --git a/QodeAssistSettings.hpp b/QodeAssistSettings.hpp index 641f9d2..5ea788c 100644 --- a/QodeAssistSettings.hpp +++ b/QodeAssistSettings.hpp @@ -97,6 +97,7 @@ public: Utils::BoolAspect useFilePathInContext{this}; Utils::BoolAspect multiLineCompletion{this}; + Utils::StringAspect customJsonTemplate{this}; Utils::StringAspect apiKey{this}; ButtonAspect resetToDefaults{this}; diff --git a/providers/LMStudioProvider.cpp b/providers/LMStudioProvider.cpp index 4aa5df3..ca00925 100644 --- a/providers/LMStudioProvider.cpp +++ b/providers/LMStudioProvider.cpp @@ -51,7 +51,8 @@ QString LMStudioProvider::completionEndpoint() const void LMStudioProvider::prepareRequest(QJsonObject &request) { const auto ¤tTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - + if (currentTemplate->name() == "Custom Template") + return; if (request.contains("prompt")) { QJsonArray messages{ {QJsonObject{{"role", "user"}, {"content", request.take("prompt").toString()}}}}; diff --git a/providers/OllamaProvider.cpp b/providers/OllamaProvider.cpp index 41aaaa7..9bbbf29 100644 --- a/providers/OllamaProvider.cpp +++ b/providers/OllamaProvider.cpp @@ -51,6 +51,8 @@ QString OllamaProvider::completionEndpoint() const void OllamaProvider::prepareRequest(QJsonObject &request) { auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate(); + if (currentTemplate->name() == "Custom Template") + return; QJsonObject options; options["num_predict"] = settings().maxTokens(); diff --git a/providers/OpenAICompatProvider.cpp b/providers/OpenAICompatProvider.cpp index 64d0b3a..b3641e1 100644 --- a/providers/OpenAICompatProvider.cpp +++ b/providers/OpenAICompatProvider.cpp @@ -49,6 +49,8 @@ QString OpenAICompatProvider::completionEndpoint() const void OpenAICompatProvider::prepareRequest(QJsonObject &request) { const auto ¤tTemplate = PromptTemplateManager::instance().getCurrentTemplate(); + if (currentTemplate->name() == "Custom Template") + return; if (request.contains("prompt")) { QJsonArray messages{ diff --git a/qodeassist.cpp b/qodeassist.cpp index ffcd0f6..43ed118 100644 --- a/qodeassist.cpp +++ b/qodeassist.cpp @@ -46,6 +46,7 @@ #include "providers/OpenAICompatProvider.hpp" #include "templates/CodeLLamaTemplate.hpp" #include "templates/CodeQwenChat.hpp" +#include "templates/CustomTemplate.hpp" #include "templates/DeepSeekCoderV2.hpp" #include "templates/StarCoder2Template.hpp" @@ -82,6 +83,7 @@ public: templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); + templateManager.registerTemplate(); Utils::Icon QCODEASSIST_ICON( {{":/resources/images/qoderassist-icon.png", Utils::Theme::IconsBaseColor}}); diff --git a/templates/CustomTemplate.hpp b/templates/CustomTemplate.hpp new file mode 100644 index 0000000..ac4173c --- /dev/null +++ b/templates/CustomTemplate.hpp @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include "PromptTemplate.hpp" + +#include +#include + +#include "QodeAssistSettings.hpp" +#include "QodeAssistUtils.hpp" + +namespace QodeAssist::Templates { + +class CustomTemplate : public PromptTemplate +{ +public: + QString name() const override { return "Custom Template"; } + QString promptTemplate() const override { return settings().customJsonTemplate(); } + QStringList stopWords() const override { return QStringList(); } + + void prepareRequest(QJsonObject &request, const ContextData &context) const override + { + QJsonDocument doc = QJsonDocument::fromJson(promptTemplate().toUtf8()); + if (doc.isNull() || !doc.isObject()) { + logMessage(QString("Invalid JSON template in settings")); + + return; + } + + QJsonObject templateObj = doc.object(); + QJsonObject processedObj = processJsonTemplate(templateObj, context); + + for (auto it = processedObj.begin(); it != processedObj.end(); ++it) { + request[it.key()] = it.value(); + } + } + +private: + QJsonValue processJsonValue(const QJsonValue &value, const ContextData &context) const + { + if (value.isString()) { + QString str = value.toString(); + str.replace("{{QODE_INSTRUCTIONS}}", context.instriuctions); + str.replace("{{QODE_PREFIX}}", context.prefix); + str.replace("{{QODE_SUFFIX}}", context.suffix); + return str; + } else if (value.isObject()) { + return processJsonTemplate(value.toObject(), context); + } else if (value.isArray()) { + QJsonArray newArray; + for (const QJsonValue &arrayValue : value.toArray()) { + newArray.append(processJsonValue(arrayValue, context)); + } + return newArray; + } + return value; + } + + QJsonObject processJsonTemplate(const QJsonObject &templateObj, const ContextData &context) const + { + QJsonObject result; + for (auto it = templateObj.begin(); it != templateObj.end(); ++it) { + result[it.key()] = processJsonValue(it.value(), context); + } + return result; + } +}; +} // namespace QodeAssist::Templates