From c070fd5cfdec9b8917ea5ff7b9a9c9ca4c86ea36 Mon Sep 17 00:00:00 2001
From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com>
Date: Tue, 10 Dec 2024 21:27:56 +0100
Subject: [PATCH] :sparkles: feat: Add OpenRouter provider

---
Review notes (not part of the commit message):
- The new header is added as providers/OpenRouterProvider.hpp, matching the
  spelling used by CMakeLists.txt and by the #include directives. The original
  patch created providers/OpenrouterProvider.hpp, which cannot be found on
  case-sensitive filesystems.
- handleResponse() in the LM Studio, OpenAI-compatible and OpenRouter
  providers now consumes every line of a read and treats "data: [DONE]" as
  end of stream. The original returned after the first parsed line (dropping
  the rest of the read) and skipped the sentinel.
- Line breaks and indentation were reconstructed from the hunk line counts.
  Context/removed lines whose <...> text was lost in extraction (bare
  "#include", "QList", "registerProvider()") are left as found and must be
  restored from the repository before the patch is applied. The post-image
  blob ids on the index lines of the three changed .cpp files predate this
  revision.

 CMakeLists.txt                     |   1 +
 llmcore/OpenAIMessage.cpp          |  45 ++---------
 llmcore/OpenAIMessage.hpp          |   4 +-
 providers/LMStudioProvider.cpp     |  39 +++++++--
 providers/OpenAICompatProvider.cpp | 101 ++++++++---------------
 providers/OpenRouterProvider.cpp   | 132 +++++++++++++++++++++++++++++
 providers/OpenRouterProvider.hpp   |  38 +++++++++
 providers/Providers.hpp            |   2 +
 8 files changed, 251 insertions(+), 111 deletions(-)
 create mode 100644 providers/OpenRouterProvider.cpp
 create mode 100644 providers/OpenRouterProvider.hpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2a623f1..a3c45ed 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,6 +56,7 @@ add_qtc_plugin(QodeAssist
     providers/OllamaProvider.hpp providers/OllamaProvider.cpp
     providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp
     providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp
+    providers/OpenRouterProvider.hpp providers/OpenRouterProvider.cpp
     QodeAssist.qrc
     LSPCompletion.hpp
     LLMSuggestion.hpp LLMSuggestion.cpp
diff --git a/llmcore/OpenAIMessage.cpp b/llmcore/OpenAIMessage.cpp
index e100dfc..825e12e 100644
--- a/llmcore/OpenAIMessage.cpp
+++ b/llmcore/OpenAIMessage.cpp
@@ -20,46 +20,13 @@
 #include "OpenAIMessage.hpp"
 #include
 #include
-#include
 
 namespace QodeAssist::LLMCore {
 
-OpenAIMessage OpenAIMessage::fromJson(const QByteArray &data)
+OpenAIMessage OpenAIMessage::fromJson(const QJsonObject &obj)
 {
     OpenAIMessage msg;
 
-    QByteArrayList lines = data.split('\n');
-    QByteArray jsonData;
-
-    for (const QByteArray &line : lines) {
-        if (line.trimmed().isEmpty()) {
-            continue;
-        }
-
-        if (line.trimmed() == "data: [DONE]") {
-            msg.done = true;
-            continue;
-        }
-
-        if (line.startsWith("data: ")) {
-            jsonData = line.mid(6);
-            break;
-        }
-    }
-
-    if (jsonData.isEmpty()) {
-        jsonData = data;
-    }
-
-    QJsonParseError error;
-    QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error);
-    if (doc.isNull()) {
-        msg.error = QString("Invalid JSON response: %1").arg(error.errorString());
-        return msg;
-    }
-
-    QJsonObject obj = doc.object();
-
     if (obj.contains("error")) {
         msg.error = obj["error"].toObject()["message"].toString();
         return msg;
@@ -70,10 +37,12 @@ OpenAIMessage OpenAIMessage::fromJson(const QByteArray &data)
         if (!choices.isEmpty()) {
             auto choiceObj = choices[0].toObject();
 
-            if (choiceObj.contains("message")) {
-                msg.choice.content = choiceObj["message"].toObject()["content"].toString();
-            } else if (choiceObj.contains("delta")) {
-                msg.choice.content = choiceObj["delta"].toObject()["content"].toString();
+            if (choiceObj.contains("delta")) {
+                QJsonObject delta = choiceObj["delta"].toObject();
+                msg.choice.content = delta["content"].toString();
+            } else if (choiceObj.contains("message")) {
+                QJsonObject message = choiceObj["message"].toObject();
+                msg.choice.content = message["content"].toString();
             }
 
             msg.choice.finishReason = choiceObj["finish_reason"].toString();
diff --git a/llmcore/OpenAIMessage.hpp b/llmcore/OpenAIMessage.hpp
index 43c25a7..5b66b5a 100644
--- a/llmcore/OpenAIMessage.hpp
+++ b/llmcore/OpenAIMessage.hpp
@@ -46,13 +46,11 @@ public:
     bool done{false};
     Usage usage;
 
-    static OpenAIMessage fromJson(const QByteArray &data);
     QString getContent() const;
     bool hasError() const;
     bool isDone() const;
 
-private:
-    static OpenAIMessage fromJsonObject(const QJsonObject &obj);
+    static OpenAIMessage fromJson(const QJsonObject &obj);
 };
 
 } // namespace QodeAssist::LLMCore
diff --git a/providers/LMStudioProvider.cpp b/providers/LMStudioProvider.cpp
index eec30f4..68de130 100644
--- a/providers/LMStudioProvider.cpp
+++ b/providers/LMStudioProvider.cpp
@@ -107,14 +107,41 @@ bool LMStudioProvider::handleResponse(QNetworkReply *reply, QString &accumulated
         return false;
     }
 
-    auto message = LLMCore::OpenAIMessage::fromJson(data);
-    if (message.hasError()) {
-        LOG_MESSAGE("Error in OpenAI response: " + message.error);
-        return false;
+    // A single read may carry several SSE events; consume every line instead
+    // of stopping at the first one so no streamed content is dropped.
+    bool isDone = false;
+    const QByteArrayList chunks = data.split('\n');
+    for (const QByteArray &chunk : chunks) {
+        const QByteArray line = chunk.trimmed();
+        if (line.isEmpty()) {
+            continue;
+        }
+
+        // "[DONE]" is the end-of-stream sentinel, not JSON.
+        if (line == "data: [DONE]") {
+            isDone = true;
+            continue;
+        }
+
+        const QByteArray jsonData = line.startsWith("data: ") ? line.mid(6) : line;
+
+        QJsonParseError error;
+        const QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error);
+        if (doc.isNull()) {
+            continue;
+        }
+
+        const auto message = LLMCore::OpenAIMessage::fromJson(doc.object());
+        if (message.hasError()) {
+            LOG_MESSAGE("Error in LMStudioProvider response: " + message.error);
+            continue;
+        }
+
+        accumulatedResponse += message.getContent();
+        isDone = isDone || message.isDone();
     }
 
-    accumulatedResponse += message.getContent();
-    return message.isDone();
+    return isDone;
 }
 
 QList LMStudioProvider::getInstalledModels(const QString &url)
diff --git a/providers/OpenAICompatProvider.cpp b/providers/OpenAICompatProvider.cpp
index 5f941da..2dee0be 100644
--- a/providers/OpenAICompatProvider.cpp
+++ b/providers/OpenAICompatProvider.cpp
@@ -26,6 +26,7 @@
 #include
 #include
 
+#include "llmcore/OpenAIMessage.hpp"
 #include "logger/Logger.hpp"
 
 namespace QodeAssist::Providers {
@@ -100,74 +101,46 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::Request
 
 bool OpenAICompatProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
 {
-    bool isComplete = false;
-    QString tempResponse = accumulatedResponse;
-
-    while (reply->canReadLine()) {
-        QByteArray line = reply->readLine().trimmed();
-        if (line.isEmpty()) {
-            continue;
-        }
-
-        if (!line.startsWith("data:")) {
-            continue;
-        }
-
-        line = line.mid(6);
-
-        if (line == "[DONE]") {
-            isComplete = true;
-            break;
-        }
-
-        QJsonDocument jsonResponse = QJsonDocument::fromJson(line);
-        if (jsonResponse.isNull()) {
-            LOG_MESSAGE(
-                "Invalid JSON response from OpenAI compatible provider: " + QString::fromUtf8(line));
-            continue;
-        }
-
-        QJsonObject responseObj = jsonResponse.object();
-
-        if (responseObj.contains("error")) {
-            LOG_MESSAGE(
-                "OpenAI compatible provider error: "
-                + QString::fromUtf8(QJsonDocument(responseObj).toJson(QJsonDocument::Indented)));
-            return false;
-        }
-
-        if (responseObj.contains("choices")) {
-            QJsonArray choices = responseObj["choices"].toArray();
-            if (!choices.isEmpty()) {
-                QJsonObject choice = choices.first().toObject();
-                QJsonObject delta = choice["delta"].toObject();
-                if (delta.contains("content")) {
-                    QString completion = delta["content"].toString();
-                    if (!completion.isEmpty()) {
-                        tempResponse += completion;
-                    }
-                }
-                QString finishReason = choice["finish_reason"].toString();
-                if (!finishReason.isNull() && finishReason == "stop") {
-                    isComplete = true;
-                }
-            }
-        }
-
-        if (responseObj.contains("usage")) {
-            QJsonObject usage = responseObj["usage"].toObject();
-            LOG_MESSAGE(QString("Token usage - Prompt: %1, Completion: %2, Total: %3")
-                            .arg(usage["prompt_tokens"].toInt())
-                            .arg(usage["completion_tokens"].toInt())
-                            .arg(usage["total_tokens"].toInt()));
-        }
+    QByteArray data = reply->readAll();
+    if (data.isEmpty()) {
+        return false;
     }
 
-    if (!tempResponse.isEmpty()) {
-        accumulatedResponse = tempResponse;
+    // A single read may carry several SSE events; consume every line instead
+    // of stopping at the first one so no streamed content is dropped.
+    bool isDone = false;
+    const QByteArrayList chunks = data.split('\n');
+    for (const QByteArray &chunk : chunks) {
+        const QByteArray line = chunk.trimmed();
+        if (line.isEmpty()) {
+            continue;
+        }
+
+        // "[DONE]" is the end-of-stream sentinel, not JSON.
+        if (line == "data: [DONE]") {
+            isDone = true;
+            continue;
+        }
+
+        const QByteArray jsonData = line.startsWith("data: ") ? line.mid(6) : line;
+
+        QJsonParseError error;
+        const QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error);
+        if (doc.isNull()) {
+            continue;
+        }
+
+        const auto message = LLMCore::OpenAIMessage::fromJson(doc.object());
+        if (message.hasError()) {
+            LOG_MESSAGE("Error in OpenAI response: " + message.error);
+            continue;
+        }
+
+        accumulatedResponse += message.getContent();
+        isDone = isDone || message.isDone();
     }
 
-    return isComplete;
+    return isDone;
 }
 
 QList OpenAICompatProvider::getInstalledModels(const QString &url)
diff --git a/providers/OpenRouterProvider.cpp b/providers/OpenRouterProvider.cpp
new file mode 100644
index 0000000..5485cbf
--- /dev/null
+++ b/providers/OpenRouterProvider.cpp
@@ -0,0 +1,132 @@
+/*
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "OpenRouterProvider.hpp"
+#include "settings/ChatAssistantSettings.hpp"
+#include "settings/CodeCompletionSettings.hpp"
+
+#include <QJsonArray>
+#include <QJsonDocument>
+#include <QJsonObject>
+#include <QNetworkReply>
+
+#include "llmcore/OpenAIMessage.hpp"
+#include "logger/Logger.hpp"
+
+namespace QodeAssist::Providers {
+
+OpenRouterProvider::OpenRouterProvider() {}
+
+QString OpenRouterProvider::name() const
+{
+    return "OpenRouter";
+}
+
+QString OpenRouterProvider::url() const
+{
+    return "https://openrouter.ai/api";
+}
+
+void OpenRouterProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType type)
+{
+    auto prepareMessages = [](QJsonObject &req) -> QJsonArray {
+        QJsonArray messages;
+        if (req.contains("system")) {
+            messages.append(
+                QJsonObject{{"role", "system"}, {"content", req.take("system").toString()}});
+        }
+        if (req.contains("prompt")) {
+            messages.append(
+                QJsonObject{{"role", "user"}, {"content", req.take("prompt").toString()}});
+        }
+        return messages;
+    };
+
+    auto applyModelParams = [&request](const auto &settings) {
+        request["max_tokens"] = settings.maxTokens();
+        request["temperature"] = settings.temperature();
+
+        if (settings.useTopP())
+            request["top_p"] = settings.topP();
+        if (settings.useTopK())
+            request["top_k"] = settings.topK();
+        if (settings.useFrequencyPenalty())
+            request["frequency_penalty"] = settings.frequencyPenalty();
+        if (settings.usePresencePenalty())
+            request["presence_penalty"] = settings.presencePenalty();
+    };
+
+    QJsonArray messages = prepareMessages(request);
+    if (!messages.isEmpty()) {
+        request["messages"] = std::move(messages);
+    }
+
+    if (type == LLMCore::RequestType::Fim) {
+        applyModelParams(Settings::codeCompletionSettings());
+    } else {
+        applyModelParams(Settings::chatAssistantSettings());
+    }
+}
+
+bool OpenRouterProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
+{
+    QByteArray data = reply->readAll();
+    if (data.isEmpty()) {
+        return false;
+    }
+
+    // A single read may carry several SSE events; consume every line instead
+    // of stopping at the first one so no streamed content is dropped.
+    bool isDone = false;
+    const QByteArrayList chunks = data.split('\n');
+    for (const QByteArray &chunk : chunks) {
+        const QByteArray line = chunk.trimmed();
+        // Skip OpenRouter's "OPENROUTER PROCESSING" keep-alive comment lines.
+        if (line.isEmpty() || line.contains("OPENROUTER PROCESSING")) {
+            continue;
+        }
+
+        // "[DONE]" is the end-of-stream sentinel, not JSON.
+        if (line == "data: [DONE]") {
+            isDone = true;
+            continue;
+        }
+
+        const QByteArray jsonData = line.startsWith("data: ") ? line.mid(6) : line;
+
+        QJsonParseError error;
+        const QJsonDocument doc = QJsonDocument::fromJson(jsonData, &error);
+        if (doc.isNull()) {
+            continue;
+        }
+
+        const auto message = LLMCore::OpenAIMessage::fromJson(doc.object());
+        if (message.hasError()) {
+            LOG_MESSAGE("Error in OpenRouter response: " + message.error);
+            continue;
+        }
+
+        accumulatedResponse += message.getContent();
+        isDone = isDone || message.isDone();
+    }
+
+    return isDone;
+}
+
+} // namespace QodeAssist::Providers
diff --git a/providers/OpenRouterProvider.hpp b/providers/OpenRouterProvider.hpp
new file mode 100644
index 0000000..6f06945
--- /dev/null
+++ b/providers/OpenRouterProvider.hpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include "llmcore/Provider.hpp"
+#include "providers/OpenAICompatProvider.hpp"
+
+namespace QodeAssist::Providers {
+
+class OpenRouterProvider : public OpenAICompatProvider
+{
+public:
+    OpenRouterProvider();
+
+    QString name() const override;
+    QString url() const override;
+    void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
+    bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
+};
+
+} // namespace QodeAssist::Providers
diff --git a/providers/Providers.hpp b/providers/Providers.hpp
index 7bb614f..8eb691c 100644
--- a/providers/Providers.hpp
+++ b/providers/Providers.hpp
@@ -23,6 +23,7 @@
 #include "providers/LMStudioProvider.hpp"
 #include "providers/OllamaProvider.hpp"
 #include "providers/OpenAICompatProvider.hpp"
+#include "providers/OpenRouterProvider.hpp"
 
 namespace QodeAssist::Providers {
 
@@ -32,6 +33,7 @@ inline void registerProviders()
     providerManager.registerProvider();
     providerManager.registerProvider();
     providerManager.registerProvider();
+    providerManager.registerProvider<OpenRouterProvider>();
 }
 
 } // namespace QodeAssist::Providers