From 7ba615a72db4d791e2113d444647600d0068e488 Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Tue, 25 Feb 2025 14:14:48 +0100 Subject: [PATCH] feat: Add Google AI provider and template --- CMakeLists.txt | 6 + ChatView/ClientInterface.cpp | 20 ++- LLMClientInterface.cpp | 54 ++---- LLMClientInterface.hpp | 2 - context/ContextManager.cpp | 35 ++++ context/ContextManager.hpp | 3 + llmcore/Provider.hpp | 1 - llmcore/ProviderID.hpp | 4 +- providers/GoogleAIProvider.cpp | 303 +++++++++++++++++++++++++++++++++ providers/GoogleAIProvider.hpp | 51 ++++++ providers/Providers.hpp | 2 + settings/ProviderSettings.cpp | 13 ++ settings/ProviderSettings.hpp | 1 + settings/SettingsConstants.hpp | 2 + templates/GoogleAI.hpp | 72 ++++++++ templates/Templates.hpp | 2 + 16 files changed, 524 insertions(+), 47 deletions(-) create mode 100644 providers/GoogleAIProvider.cpp create mode 100644 providers/GoogleAIProvider.hpp create mode 100644 templates/GoogleAI.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5de4aa3..d028036 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,3 +75,9 @@ add_qtc_plugin(QodeAssist CodeHandler.hpp CodeHandler.cpp UpdateStatusWidget.hpp UpdateStatusWidget.cpp ) + +target_sources(QodeAssist + PRIVATE + providers/GoogleAIProvider.hpp providers/GoogleAIProvider.cpp + templates/GoogleAI.hpp +) diff --git a/ChatView/ClientInterface.cpp b/ChatView/ClientInterface.cpp index c264fb0..52667a2 100644 --- a/ChatView/ClientInterface.cpp +++ b/ChatView/ClientInterface.cpp @@ -112,11 +112,23 @@ void ClientInterface::sendMessage( config.requestType = LLMCore::RequestType::Chat; config.provider = provider; config.promptTemplate = promptTemplate; - config.url = QString("%1%2").arg(Settings::generalSettings().caUrl(), provider->chatEndpoint()); + if (provider->providerID() == LLMCore::ProviderID::GoogleAI) { + QString stream = chatAssistantSettings.stream() ? QString{"streamGenerateContent?alt=sse"} + : QString{"generateContent?"}; + config.url = QUrl(QString("%1/models/%2:%3") + .arg( + Settings::generalSettings().caUrl(), + Settings::generalSettings().caModel(), + stream)); + } else { + config.url + = QString("%1%2").arg(Settings::generalSettings().caUrl(), provider->chatEndpoint()); + config.providerRequest + = {{"model", Settings::generalSettings().caModel()}, + {"stream", chatAssistantSettings.stream()}}; + } + config.apiKey = provider->apiKey(); - config.providerRequest - = {{"model", Settings::generalSettings().caModel()}, - {"stream", chatAssistantSettings.stream()}}; config.provider ->prepareRequest(config.providerRequest, promptTemplate, context, LLMCore::RequestType::Chat); diff --git a/LLMClientInterface.cpp b/LLMClientInterface.cpp index 7602773..0784ba9 100644 --- a/LLMClientInterface.cpp +++ b/LLMClientInterface.cpp @@ -27,6 +27,7 @@ #include #include "CodeHandler.hpp" +#include "context/ContextManager.hpp" #include "context/DocumentContextReader.hpp" #include "llmcore/PromptTemplateManager.hpp" #include "llmcore/ProvidersManager.hpp" @@ -145,24 +146,13 @@ void LLMClientInterface::handleExit(const QJsonObject &request) emit finished(); } -bool QodeAssist::LLMClientInterface::isSpecifyCompletion(const QJsonObject &request) -{ - auto &generalSettings = Settings::generalSettings(); - - Context::ProgrammingLanguage documentLanguage = getDocumentLanguage(request); - Context::ProgrammingLanguage preset1Language = Context::ProgrammingLanguageUtils::fromString( - generalSettings.preset1Language.displayForIndex(generalSettings.preset1Language())); - - return generalSettings.specifyPreset1() && documentLanguage == preset1Language; -} - void LLMClientInterface::handleCompletion(const QJsonObject &request) { auto updatedContext = prepareContext(request); auto &completeSettings = Settings::codeCompletionSettings(); auto &generalSettings = Settings::generalSettings(); - bool isPreset1Active = isSpecifyCompletion(request); + bool isPreset1Active = Context::ContextManager::instance().isSpecifyCompletion(request); const auto providerName = !isPreset1Active ? generalSettings.ccProvider() : generalSettings.ccPreset1Provider(); @@ -193,14 +183,19 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request) config.requestType = LLMCore::RequestType::CodeCompletion; config.provider = provider; config.promptTemplate = promptTemplate; - config.url = QUrl(QString("%1%2").arg( - url, - promptTemplate->type() == LLMCore::TemplateType::FIM ? provider->completionEndpoint() - : provider->chatEndpoint())); + // TODO refactor networking + if (provider->providerID() == LLMCore::ProviderID::GoogleAI) { + QString stream = completeSettings.stream() ? QString{"streamGenerateContent?alt=sse"} + : QString{"generateContent?"}; + config.url = QUrl(QString("%1/models/%2:%3").arg(url, modelName, stream)); + } else { + config.url = QUrl(QString("%1%2").arg( + url, + promptTemplate->type() == LLMCore::TemplateType::FIM ? provider->completionEndpoint() + : provider->chatEndpoint())); + config.providerRequest = {{"model", modelName}, {"stream", completeSettings.stream()}}; + } config.apiKey = provider->apiKey(); - - config.providerRequest = {{"model", modelName}, {"stream", completeSettings.stream()}}; - config.multiLineCompletion = completeSettings.multiLineCompletion(); const auto stopWords = QJsonArray::fromStringList(config.promptTemplate->stopWords()); @@ -224,6 +219,7 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request) userMessage = updatedContext.prefix.value_or("") + updatedContext.suffix.value_or(""); } + // TODO refactor add message QVector messages; messages.append({"user", userMessage}); updatedContext.history = messages; @@ -268,29 +264,11 @@ LLMCore::ContextData LLMClientInterface::prepareContext(const QJsonObject &reque return reader.prepareContext(lineNumber, cursorPosition); } -Context::ProgrammingLanguage LLMClientInterface::getDocumentLanguage(const QJsonObject &request) const -{ - QJsonObject params = request["params"].toObject(); - QJsonObject doc = params["doc"].toObject(); - QString uri = doc["uri"].toString(); - - Utils::FilePath filePath = Utils::FilePath::fromString(QUrl(uri).toLocalFile()); - TextEditor::TextDocument *textDocument = TextEditor::TextDocument::textDocumentForFilePath( - filePath); - - if (!textDocument) { - LOG_MESSAGE("Error: Document is not available for" + filePath.toString()); - return Context::ProgrammingLanguage::Unknown; - } - - return Context::ProgrammingLanguageUtils::fromMimeType(textDocument->mimeType()); -} - void LLMClientInterface::sendCompletionToClient(const QString &completion, const QJsonObject &request, bool isComplete) { - bool isPreset1Active = isSpecifyCompletion(request); + bool isPreset1Active = Context::ContextManager::instance().isSpecifyCompletion(request); auto templateName = !isPreset1Active ? Settings::generalSettings().ccTemplate() : Settings::generalSettings().ccPreset1Template(); diff --git a/LLMClientInterface.hpp b/LLMClientInterface.hpp index 624b60e..0f51c9e 100644 --- a/LLMClientInterface.hpp +++ b/LLMClientInterface.hpp @@ -61,8 +61,6 @@ private: LLMCore::ContextData prepareContext( const QJsonObject &request, const QStringView &accumulatedCompletion = QString{}); - Context::ProgrammingLanguage getDocumentLanguage(const QJsonObject &request) const; - bool isSpecifyCompletion(const QJsonObject &request); LLMCore::RequestHandler m_requestHandler; QElapsedTimer m_completionTimer; diff --git a/context/ContextManager.cpp b/context/ContextManager.cpp index cce1aba..3425e6d 100644 --- a/context/ContextManager.cpp +++ b/context/ContextManager.cpp @@ -21,8 +21,14 @@ #include #include +#include #include +#include "GeneralSettings.hpp" +#include "Logger.hpp" +#include +#include + namespace QodeAssist::Context { ContextManager &ContextManager::instance() @@ -64,4 +70,33 @@ ContentFile ContextManager::createContentFile(const QString &filePath) const return contentFile; } +ProgrammingLanguage ContextManager::getDocumentLanguage(const QJsonObject &request) const +{ + QJsonObject params = request["params"].toObject(); + QJsonObject doc = params["doc"].toObject(); + QString uri = doc["uri"].toString(); + + Utils::FilePath filePath = Utils::FilePath::fromString(QUrl(uri).toLocalFile()); + TextEditor::TextDocument *textDocument = TextEditor::TextDocument::textDocumentForFilePath( + filePath); + + if (!textDocument) { + LOG_MESSAGE("Error: Document is not available for" + filePath.toString()); + return Context::ProgrammingLanguage::Unknown; + } + + return Context::ProgrammingLanguageUtils::fromMimeType(textDocument->mimeType()); +} + +bool ContextManager::isSpecifyCompletion(const QJsonObject &request) +{ + auto &generalSettings = Settings::generalSettings(); + + Context::ProgrammingLanguage documentLanguage = getDocumentLanguage(request); + Context::ProgrammingLanguage preset1Language = Context::ProgrammingLanguageUtils::fromString( + generalSettings.preset1Language.displayForIndex(generalSettings.preset1Language())); + + return generalSettings.specifyPreset1() && documentLanguage == preset1Language; +} + } // namespace QodeAssist::Context diff --git a/context/ContextManager.hpp b/context/ContextManager.hpp index 78e60f7..4f25ecc 100644 --- a/context/ContextManager.hpp +++ b/context/ContextManager.hpp @@ -23,6 +23,7 @@ #include #include "ContentFile.hpp" +#include "ProgrammingLanguage.hpp" namespace QodeAssist::Context { @@ -34,6 +35,8 @@ public: static ContextManager &instance(); QString readFile(const QString &filePath) const; QList getContentFiles(const QStringList &filePaths) const; + ProgrammingLanguage getDocumentLanguage(const QJsonObject &request) const; + bool isSpecifyCompletion(const QJsonObject &request); private: explicit ContextManager(QObject *parent = nullptr); diff --git a/llmcore/Provider.hpp b/llmcore/Provider.hpp index 5c820a9..4e135a0 100644 --- a/llmcore/Provider.hpp +++ b/llmcore/Provider.hpp @@ -42,7 +42,6 @@ public: virtual QString completionEndpoint() const = 0; virtual QString chatEndpoint() const = 0; virtual bool supportsModelListing() const = 0; - virtual void prepareRequest( QJsonObject &request, LLMCore::PromptTemplate *prompt, diff --git a/llmcore/ProviderID.hpp b/llmcore/ProviderID.hpp index 871fb26..b15a484 100644 --- a/llmcore/ProviderID.hpp +++ b/llmcore/ProviderID.hpp @@ -27,7 +27,7 @@ enum class ProviderID { OpenAI, OpenAICompatible, MistralAI, - OpenRouter + OpenRouter, + GoogleAI }; - } diff --git a/providers/GoogleAIProvider.cpp b/providers/GoogleAIProvider.cpp new file mode 100644 index 0000000..d0815ad --- /dev/null +++ b/providers/GoogleAIProvider.cpp @@ -0,0 +1,303 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#include "GoogleAIProvider.hpp" + +#include +#include +#include +#include +#include +#include + +#include "llmcore/ValidationUtils.hpp" +#include "logger/Logger.hpp" +#include "settings/ChatAssistantSettings.hpp" +#include "settings/CodeCompletionSettings.hpp" +#include "settings/GeneralSettings.hpp" +#include "settings/ProviderSettings.hpp" + +namespace QodeAssist::Providers { + +QString GoogleAIProvider::name() const +{ + return "Google AI"; +} + +QString GoogleAIProvider::url() const +{ + return "https://generativelanguage.googleapis.com/v1beta"; +} + +QString GoogleAIProvider::completionEndpoint() const +{ + return {}; +} + +QString GoogleAIProvider::chatEndpoint() const +{ + return {}; +} + +bool GoogleAIProvider::supportsModelListing() const +{ + return true; +} + +void GoogleAIProvider::prepareRequest( + QJsonObject &request, + LLMCore::PromptTemplate *prompt, + LLMCore::ContextData context, + LLMCore::RequestType type) +{ + if (!prompt->isSupportProvider(providerID())) { + LOG_MESSAGE(QString("Template %1 doesn't support %2 provider").arg(name(), prompt->name())); + } + + prompt->prepareRequest(request, context); + + auto applyModelParams = [&request](const auto &settings) { + QJsonObject generationConfig; + generationConfig["maxOutputTokens"] = settings.maxTokens(); + generationConfig["temperature"] = settings.temperature(); + + if (settings.useTopP()) + generationConfig["topP"] = settings.topP(); + if (settings.useTopK()) + generationConfig["topK"] = settings.topK(); + + request["generationConfig"] = generationConfig; + }; + + if (type == LLMCore::RequestType::CodeCompletion) { + applyModelParams(Settings::codeCompletionSettings()); + } else { + applyModelParams(Settings::chatAssistantSettings()); + } +} + +bool GoogleAIProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse) +{ + if (reply->isFinished()) { + if (reply->bytesAvailable() > 0) { + QByteArray data = reply->readAll(); + + if (data.startsWith("data: ")) { + return handleStreamResponse(data, accumulatedResponse); + } else { + return handleRegularResponse(data, accumulatedResponse); + } + } + + return true; + } + + QByteArray data = reply->readAll(); + if (data.isEmpty()) { + return false; + } + + if (data.startsWith("data: ")) { + return handleStreamResponse(data, accumulatedResponse); + } else { + return handleRegularResponse(data, accumulatedResponse); + } +} + +QList GoogleAIProvider::getInstalledModels(const QString &url) +{ + QList models; + + QNetworkAccessManager manager; + QNetworkRequest request(QString("%1/models?key=%2").arg(url, apiKey())); + + request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + + QNetworkReply *reply = manager.get(request); + QEventLoop loop; + QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit); + loop.exec(); + + if (reply->error() == QNetworkReply::NoError) { + QByteArray responseData = reply->readAll(); + QJsonDocument jsonResponse = QJsonDocument::fromJson(responseData); + QJsonObject jsonObject = jsonResponse.object(); + + if (jsonObject.contains("models")) { + QJsonArray modelArray = jsonObject["models"].toArray(); + models.clear(); + + for (const QJsonValue &value : modelArray) { + QJsonObject modelObject = value.toObject(); + if (modelObject.contains("name")) { + QString modelName = modelObject["name"].toString(); + if (modelName.contains("/")) { + modelName = modelName.split("/").last(); + } + models.append(modelName); + } + } + } + } else { + LOG_MESSAGE(QString("Error fetching Google AI models: %1").arg(reply->errorString())); + } + + reply->deleteLater(); + return models; +} + +QList GoogleAIProvider::validateRequest( + const QJsonObject &request, LLMCore::TemplateType type) +{ + QJsonObject templateReq; + + templateReq = QJsonObject{ + {"contents", QJsonArray{}}, + {"system_instruction", QJsonArray{}}, + {"generationConfig", + QJsonObject{{"temperature", {}}, {"maxOutputTokens", {}}, {"topP", {}}, {"topK", {}}}}, + {"safetySettings", QJsonArray{}}}; + + return LLMCore::ValidationUtils::validateRequestFields(request, templateReq); +} + +QString GoogleAIProvider::apiKey() const +{ + return Settings::providerSettings().googleAiApiKey(); +} + +void GoogleAIProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const +{ + networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + + QUrl url = networkRequest.url(); + QUrlQuery query(url.query()); + query.addQueryItem("key", apiKey()); + url.setQuery(query); + networkRequest.setUrl(url); +} + +LLMCore::ProviderID GoogleAIProvider::providerID() const +{ + return LLMCore::ProviderID::GoogleAI; +} + +bool GoogleAIProvider::handleStreamResponse(const QByteArray &data, QString &accumulatedResponse) +{ + QByteArrayList lines = data.split('\n'); + bool isDone = false; + + for (const QByteArray &line : lines) { + QByteArray trimmedLine = line.trimmed(); + if (trimmedLine.isEmpty()) { + continue; + } + + if (trimmedLine == "data: [DONE]") { + isDone = true; + continue; + } + + if (trimmedLine.startsWith("data: ")) { + QByteArray jsonData = trimmedLine.mid(6); // Remove "data: " prefix + QJsonDocument doc = QJsonDocument::fromJson(jsonData); + if (doc.isNull() || !doc.isObject()) { + continue; + } + + QJsonObject responseObj = doc.object(); + + if (responseObj.contains("error")) { + QJsonObject error = responseObj["error"].toObject(); + LOG_MESSAGE("Error in Google AI stream response: " + error["message"].toString()); + continue; + } + + if (responseObj.contains("candidates")) { + QJsonArray candidates = responseObj["candidates"].toArray(); + if (!candidates.isEmpty()) { + QJsonObject candidate = candidates.first().toObject(); + + if (candidate.contains("finishReason") + && !candidate["finishReason"].toString().isEmpty()) { + isDone = true; + } + + if (candidate.contains("content")) { + QJsonObject content = candidate["content"].toObject(); + if (content.contains("parts")) { + QJsonArray parts = content["parts"].toArray(); + for (const auto &part : parts) { + QJsonObject partObj = part.toObject(); + if (partObj.contains("text")) { + accumulatedResponse += partObj["text"].toString(); + } + } + } + } + } + } + } + } + + return isDone; +} + +bool GoogleAIProvider::handleRegularResponse(const QByteArray &data, QString &accumulatedResponse) +{ + QJsonDocument doc = QJsonDocument::fromJson(data); + if (doc.isNull() || !doc.isObject()) { + LOG_MESSAGE("Invalid JSON response from Google AI API"); + return false; + } + + QJsonObject response = doc.object(); + + if (response.contains("error")) { + QJsonObject error = response["error"].toObject(); + LOG_MESSAGE("Error in Google AI response: " + error["message"].toString()); + return false; + } + + if (!response.contains("candidates") || response["candidates"].toArray().isEmpty()) { + return false; + } + + QJsonObject candidate = response["candidates"].toArray().first().toObject(); + if (!candidate.contains("content")) { + return false; + } + + QJsonObject content = candidate["content"].toObject(); + if (!content.contains("parts")) { + return false; + } + + QJsonArray parts = content["parts"].toArray(); + for (const auto &part : parts) { + QJsonObject partObj = part.toObject(); + if (partObj.contains("text")) { + accumulatedResponse += partObj["text"].toString(); + } + } + + return true; +} + +} // namespace QodeAssist::Providers diff --git a/providers/GoogleAIProvider.hpp b/providers/GoogleAIProvider.hpp new file mode 100644 index 0000000..3dc03df --- /dev/null +++ b/providers/GoogleAIProvider.hpp @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include "llmcore/Provider.hpp" + +namespace QodeAssist::Providers { + +class GoogleAIProvider : public LLMCore::Provider +{ +public: + QString name() const override; + QString url() const override; + QString completionEndpoint() const override; + QString chatEndpoint() const override; + bool supportsModelListing() const override; + void prepareRequest( + QJsonObject &request, + LLMCore::PromptTemplate *prompt, + LLMCore::ContextData context, + LLMCore::RequestType type) override; + bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; + QList getInstalledModels(const QString &url) override; + QList validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override; + QString apiKey() const override; + void prepareNetworkRequest(QNetworkRequest &networkRequest) const override; + LLMCore::ProviderID providerID() const override; + +private: + bool handleStreamResponse(const QByteArray &data, QString &accumulatedResponse); + bool handleRegularResponse(const QByteArray &data, QString &accumulatedResponse); +}; + +} // namespace QodeAssist::Providers diff --git a/providers/Providers.hpp b/providers/Providers.hpp index 81fdd28..0a6999d 100644 --- a/providers/Providers.hpp +++ b/providers/Providers.hpp @@ -21,6 +21,7 @@ #include "llmcore/ProvidersManager.hpp" #include "providers/ClaudeProvider.hpp" +#include "providers/GoogleAIProvider.hpp" #include "providers/LMStudioProvider.hpp" #include "providers/MistralAIProvider.hpp" #include "providers/OllamaProvider.hpp" @@ -40,6 +41,7 @@ inline void registerProviders() providerManager.registerProvider(); providerManager.registerProvider(); providerManager.registerProvider(); + providerManager.registerProvider(); } } // namespace QodeAssist::Providers diff --git a/settings/ProviderSettings.cpp b/settings/ProviderSettings.cpp index db5422a..890105a 100644 --- a/settings/ProviderSettings.cpp +++ b/settings/ProviderSettings.cpp @@ -87,6 +87,15 @@ ProviderSettings::ProviderSettings() mistralAiApiKey.setDefaultValue(""); mistralAiApiKey.setAutoApply(true); + // GoogleAI Settings + googleAiApiKey.setSettingsKey(Constants::GOOGLE_AI_API_KEY); + googleAiApiKey.setLabelText(Tr::tr("Google AI API Key:")); + googleAiApiKey.setDisplayStyle(Utils::StringAspect::LineEditDisplay); + googleAiApiKey.setPlaceHolderText(Tr::tr("Enter your API key here")); + googleAiApiKey.setHistoryCompleter(Constants::GOOGLE_AI_API_KEY_HISTORY); + googleAiApiKey.setDefaultValue(""); + googleAiApiKey.setAutoApply(true); + resetToDefaults.m_buttonText = Tr::tr("Reset Page to Defaults"); readSettings(); @@ -108,6 +117,8 @@ ProviderSettings::ProviderSettings() Group{title(Tr::tr("Claude Settings")), Column{claudeApiKey}}, Space{8}, Group{title(Tr::tr("Mistral AI Settings")), Column{mistralAiApiKey}}, + Space{8}, + Group{title(Tr::tr("Google AI Settings")), Column{googleAiApiKey}}, Stretch{1}}; }); } @@ -140,6 +151,8 @@ void ProviderSettings::resetSettingsToDefaults() resetAspect(openAiCompatApiKey); resetAspect(claudeApiKey); resetAspect(openAiApiKey); + resetAspect(mistralAiApiKey); + resetAspect(googleAiApiKey); } } diff --git a/settings/ProviderSettings.hpp b/settings/ProviderSettings.hpp index 26cb9b3..9e84b89 100644 --- a/settings/ProviderSettings.hpp +++ b/settings/ProviderSettings.hpp @@ -38,6 +38,7 @@ public: Utils::StringAspect claudeApiKey{this}; Utils::StringAspect openAiApiKey{this}; Utils::StringAspect mistralAiApiKey{this}; + Utils::StringAspect googleAiApiKey{this}; private: void setupConnections(); diff --git a/settings/SettingsConstants.hpp b/settings/SettingsConstants.hpp index 5166c85..74336a5 100644 --- a/settings/SettingsConstants.hpp +++ b/settings/SettingsConstants.hpp @@ -102,6 +102,8 @@ const char OPEN_AI_API_KEY[] = "QodeAssist.openAiApiKey"; const char OPEN_AI_API_KEY_HISTORY[] = "QodeAssist.openAiApiKeyHistory"; const char MISTRAL_AI_API_KEY[] = "QodeAssist.mistralAiApiKey"; const char MISTRAL_AI_API_KEY_HISTORY[] = "QodeAssist.mistralAiApiKeyHistory"; +const char GOOGLE_AI_API_KEY[] = "QodeAssist.googleAiApiKey"; +const char GOOGLE_AI_API_KEY_HISTORY[] = "QodeAssist.googleAiApiKeyHistory"; // context settings const char CC_READ_FULL_FILE[] = "QodeAssist.ccReadFullFile"; diff --git a/templates/GoogleAI.hpp b/templates/GoogleAI.hpp new file mode 100644 index 0000000..28932c9 --- /dev/null +++ b/templates/GoogleAI.hpp @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include +#include + +#include "llmcore/PromptTemplate.hpp" + +namespace QodeAssist::Templates { + +class GoogleAI : public LLMCore::PromptTemplate +{ +public: + LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; } + QString name() const override { return "Google AI"; } + QStringList stopWords() const override { return QStringList(); } + + void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override + { + QJsonArray contents; + + if (context.systemPrompt && !context.systemPrompt->isEmpty()) { + request["system_instruction"] = QJsonObject{ + {"parts", QJsonObject{{"text", context.systemPrompt.value()}}}}; + } + + for (const auto &msg : context.history.value()) { + QJsonObject content; + QJsonArray parts; + + parts.append(QJsonObject{{"text", msg.content}}); + + QString role = msg.role; + if (role == "assistant") { + role = "model"; + } + + content["role"] = role; + content["parts"] = parts; + contents.append(content); + } + + request["contents"] = contents; + } + + QString description() const override { return "Google AI (Gemini)"; } + + bool isSupportProvider(LLMCore::ProviderID id) const override + { + return id == QodeAssist::LLMCore::ProviderID::GoogleAI; + } +}; + +} // namespace QodeAssist::Templates diff --git a/templates/Templates.hpp b/templates/Templates.hpp index d9e9f95..04775fd 100644 --- a/templates/Templates.hpp +++ b/templates/Templates.hpp @@ -31,6 +31,7 @@ #include "templates/OpenAICompatible.hpp" // #include "templates/CustomFimTemplate.hpp" // #include "templates/DeepSeekCoderFim.hpp" +#include "templates/GoogleAI.hpp" #include "templates/Llama2.hpp" #include "templates/Llama3.hpp" #include "templates/Qwen.hpp" @@ -58,6 +59,7 @@ inline void registerTemplates() templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); + templateManager.registerTemplate(); } } // namespace QodeAssist::Templates