diff --git a/CMakeLists.txt b/CMakeLists.txt index 8576f60..12adba6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,13 +53,15 @@ add_qtc_plugin(QodeAssist templates/ChatML.hpp templates/Alpaca.hpp templates/Llama2.hpp - providers/Providers.hpp templates/Claude.hpp + templates/OpenAI.hpp + providers/Providers.hpp providers/OllamaProvider.hpp providers/OllamaProvider.cpp providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp providers/OpenRouterAIProvider.hpp providers/OpenRouterAIProvider.cpp providers/ClaudeProvider.hpp providers/ClaudeProvider.cpp + providers/OpenAIProvider.hpp providers/OpenAIProvider.cpp QodeAssist.qrc LSPCompletion.hpp LLMSuggestion.hpp LLMSuggestion.cpp diff --git a/providers/OpenAIProvider.cpp b/providers/OpenAIProvider.cpp new file mode 100644 index 0000000..f090f58 --- /dev/null +++ b/providers/OpenAIProvider.cpp @@ -0,0 +1,229 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#include "OpenAIProvider.hpp" + +#include "settings/ChatAssistantSettings.hpp" +#include "settings/CodeCompletionSettings.hpp" +#include "settings/ProviderSettings.hpp" + +#include +#include +#include +#include +#include + +#include "llmcore/OpenAIMessage.hpp" +#include "llmcore/ValidationUtils.hpp" +#include "logger/Logger.hpp" + +namespace QodeAssist::Providers { + +OpenAIProvider::OpenAIProvider() {} + +QString OpenAIProvider::name() const +{ + return "OpenAI"; +} + +QString OpenAIProvider::url() const +{ + return "https://api.openai.com"; +} + +QString OpenAIProvider::completionEndpoint() const +{ + return "/v1/chat/completions"; +} + +QString OpenAIProvider::chatEndpoint() const +{ + return "/v1/chat/completions"; +} + +bool OpenAIProvider::supportsModelListing() const +{ + return true; +} + +void OpenAIProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type) +{ + auto prepareMessages = [](QJsonObject &req) -> QJsonArray { + QJsonArray messages; + if (req.contains("system")) { + messages.append( + QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}}); + } + if (req.contains("prompt")) { + messages.append( + QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}}); + } + return messages; + }; + + auto applyModelParams = [&request](const auto &settings) { + request["max_tokens"] = settings.maxTokens(); + request["temperature"] = settings.temperature(); + + if (settings.useTopP()) + request["top_p"] = settings.topP(); + if (settings.useTopK()) + request["top_k"] = settings.topK(); + if (settings.useFrequencyPenalty()) + request["frequency_penalty"] = settings.frequencyPenalty(); + if (settings.usePresencePenalty()) + request["presence_penalty"] = settings.presencePenalty(); + }; + + QJsonArray messages = prepareMessages(request); + if (!messages.isEmpty()) { + request["messages"] = std::move(messages); + } + + if (type == LLMCore::RequestType::CodeCompletion) { + applyModelParams(Settings::codeCompletionSettings()); + } else { + applyModelParams(Settings::chatAssistantSettings()); + } +} + +bool OpenAIProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse) +{ + QByteArray data = reply->readAll(); + if (data.isEmpty()) { + return false; + } + + bool isDone = false; + QByteArrayList lines = data.split('\n'); + + for (const QByteArray &line : lines) { + if (line.trimmed().isEmpty()) { + continue; + } + + if (line == "data: [DONE]") { + isDone = true; + continue; + } + + QByteArray jsonData = line; + if (line.startsWith("data: ")) { + jsonData = line.mid(6); + } + + QJsonParseError error; + QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error); + + if (doc.isNull()) { + continue; + } + + auto message = LLMCore::OpenAIMessage::fromJson(doc.object()); + if (message.hasError()) { + LOG_MESSAGE("Error in OpenAI response: " + message.error); + continue; + } + + QString content = message.getContent(); + if (!content.isEmpty()) { + accumulatedResponse += content; + } + + if (message.isDone()) { + isDone = true; + } + } + + return isDone; +} + +QList OpenAIProvider::getInstalledModels(const QString &url) +{ + QList models; + QNetworkAccessManager manager; + QNetworkRequest request(QString("%1/v1/models").arg(url)); + + request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + if (!apiKey().isEmpty()) { + request.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8()); + } + + QNetworkReply *reply = manager.get(request); + QEventLoop loop; + QObject::connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit); + loop.exec(); + + if (reply->error() == QNetworkReply::NoError) { + QByteArray responseData = reply->readAll(); + QJsonDocument jsonResponse = QJsonDocument::fromJson(responseData); + QJsonObject jsonObject = jsonResponse.object(); + + if (jsonObject.contains("data")) { + QJsonArray modelArray = jsonObject["data"].toArray(); + for (const QJsonValue &value : modelArray) { + QJsonObject modelObject = value.toObject(); + if (modelObject.contains("id")) { + QString modelId = modelObject["id"].toString(); + if (modelId.startsWith("gpt")) { + models.append(modelId); + } + } + } + } + } else { + LOG_MESSAGE(QString("Error fetching ChatGPT models: %1").arg(reply->errorString())); + } + + reply->deleteLater(); + return models; +} + +QList OpenAIProvider::validateRequest(const QJsonObject &request, LLMCore::TemplateType type) +{ + const auto templateReq = QJsonObject{ + {"model", {}}, + {"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}}, + {"temperature", {}}, + {"max_tokens", {}}, + {"top_p", {}}, + {"top_k", {}}, + {"frequency_penalty", {}}, + {"presence_penalty", {}}, + {"stop", QJsonArray{}}, + {"stream", {}}}; + + return LLMCore::ValidationUtils::validateRequestFields(request, templateReq); +} + +QString OpenAIProvider::apiKey() const +{ + return Settings::providerSettings().openAiCompatApiKey(); +} + +void OpenAIProvider::prepareNetworkRequest(QNetworkRequest &networkRequest) const +{ + networkRequest.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); + + if (!apiKey().isEmpty()) { + networkRequest.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8()); + } +} + +} // namespace QodeAssist::Providers diff --git a/providers/OpenAIProvider.hpp b/providers/OpenAIProvider.hpp new file mode 100644 index 0000000..d47b1e1 --- /dev/null +++ b/providers/OpenAIProvider.hpp @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include "llmcore/Provider.hpp" + +namespace QodeAssist::Providers { + +class OpenAIProvider : public LLMCore::Provider +{ +public: + OpenAIProvider(); + + QString name() const override; + QString url() const override; + QString completionEndpoint() const override; + QString chatEndpoint() const override; + bool supportsModelListing() const override; + void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override; + bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; + QList getInstalledModels(const QString &url) override; + QList validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override; + QString apiKey() const override; + void prepareNetworkRequest(QNetworkRequest &networkRequest) const override; +}; + +} // namespace QodeAssist::Providers diff --git a/providers/Providers.hpp b/providers/Providers.hpp index a287adf..811cb1e 100644 --- a/providers/Providers.hpp +++ b/providers/Providers.hpp @@ -24,6 +24,7 @@ #include "providers/LMStudioProvider.hpp" #include "providers/OllamaProvider.hpp" #include "providers/OpenAICompatProvider.hpp" +#include "providers/OpenAIProvider.hpp" #include "providers/OpenRouterAIProvider.hpp" namespace QodeAssist::Providers { @@ -36,6 +37,7 @@ inline void registerProviders() providerManager.registerProvider(); providerManager.registerProvider(); providerManager.registerProvider(); + providerManager.registerProvider(); } } // namespace QodeAssist::Providers diff --git a/templates/OpenAI.hpp b/templates/OpenAI.hpp new file mode 100644 index 0000000..e5bc467 --- /dev/null +++ b/templates/OpenAI.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include + +#include "llmcore/PromptTemplate.hpp" + +namespace QodeAssist::Templates { + +class OpenAI : public LLMCore::PromptTemplate +{ +public: + LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Chat; } + QString name() const override { return "OpenAI"; } + QString promptTemplate() const override { return {}; } + QStringList stopWords() const override { return QStringList(); } + void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override {} + QString description() const override { return "OpenAI"; } +}; + +} // namespace QodeAssist::Templates diff --git a/templates/Templates.hpp b/templates/Templates.hpp index f7e4f01..80f215c 100644 --- a/templates/Templates.hpp +++ b/templates/Templates.hpp @@ -30,6 +30,7 @@ #include "templates/Llama2.hpp" #include "templates/Llama3.hpp" #include "templates/Ollama.hpp" +#include "templates/OpenAI.hpp" #include "templates/Qwen.hpp" #include "templates/StarCoder2Fim.hpp" @@ -51,6 +52,7 @@ inline void registerTemplates() templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); + templateManager.registerTemplate(); } } // namespace QodeAssist::Templates