diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2a623f1..16f4fec 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,6 +56,7 @@ add_qtc_plugin(QodeAssist
   providers/OllamaProvider.hpp providers/OllamaProvider.cpp
   providers/LMStudioProvider.hpp providers/LMStudioProvider.cpp
   providers/OpenAICompatProvider.hpp providers/OpenAICompatProvider.cpp
+  providers/OllamaMessage.hpp providers/OllamaMessage.cpp
   QodeAssist.qrc
   LSPCompletion.hpp
   LLMSuggestion.hpp LLMSuggestion.cpp
diff --git a/providers/OllamaMessage.cpp b/providers/OllamaMessage.cpp
new file mode 100644
index 0000000..2e16f9f
--- /dev/null
+++ b/providers/OllamaMessage.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include "OllamaMessage.hpp"
+
+namespace QodeAssist::Providers {
+
+OllamaMessage OllamaMessage::fromJson(const QJsonObject &obj, Type type)
+{
+    OllamaMessage msg;
+    msg.model = obj["model"].toString();
+    msg.createdAt = QDateTime::fromString(obj["created_at"].toString(), Qt::ISODate);
+    msg.done = obj["done"].toBool();
+    msg.doneReason = obj["done_reason"].toString();
+    msg.error = obj["error"].toString();
+
+    if (type == Type::Generate) {
+        auto &genResponse = msg.response.emplace<GenerateResponse>();
+        genResponse.response = obj["response"].toString();
+        if (msg.done && obj.contains("context")) {
+            const auto array = obj["context"].toArray();
+            genResponse.context.reserve(array.size());
+            for (const auto &val : array) {
+                genResponse.context.append(val.toInt());
+            }
+        }
+    } else {
+        auto &chatResponse = msg.response.emplace<ChatResponse>();
+        const auto msgObj = obj["message"].toObject();
+        chatResponse.role = msgObj["role"].toString();
+        chatResponse.content = msgObj["content"].toString();
+    }
+
+    if (msg.done) {
+        msg.metrics
+            = {obj["total_duration"].toVariant().toLongLong(),
+               obj["load_duration"].toVariant().toLongLong(),
+               obj["prompt_eval_count"].toVariant().toLongLong(),
+               obj["prompt_eval_duration"].toVariant().toLongLong(),
+               obj["eval_count"].toVariant().toLongLong(),
+               obj["eval_duration"].toVariant().toLongLong()};
+    }
+
+    return msg;
+}
+
+QString OllamaMessage::getContent() const
+{
+    if (std::holds_alternative<GenerateResponse>(response)) {
+        return std::get<GenerateResponse>(response).response;
+    }
+    return std::get<ChatResponse>(response).content;
+}
+
+bool OllamaMessage::hasError() const
+{
+    return !error.isEmpty();
+}
+
+} // namespace QodeAssist::Providers
diff --git a/providers/OllamaMessage.hpp b/providers/OllamaMessage.hpp
new file mode 100644
index 0000000..8701a4a
--- /dev/null
+++ b/providers/OllamaMessage.hpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (C) 2024 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+#include <QDateTime>
+#include <QJsonObject>
+#include <QVector>
+#include <variant>
+
+namespace QodeAssist::Providers {
+
+class OllamaMessage
+{
+public:
+    enum class Type { Generate, Chat };
+
+    struct Metrics
+    {
+        qint64 totalDuration{0};
+        qint64 loadDuration{0};
+        qint64 promptEvalCount{0};
+        qint64 promptEvalDuration{0};
+        qint64 evalCount{0};
+        qint64 evalDuration{0};
+    };
+
+    struct GenerateResponse
+    {
+        QString response;
+        QVector<int> context;
+    };
+
+    struct ChatResponse
+    {
+        QString role;
+        QString content;
+    };
+
+    QString model;
+    QDateTime createdAt;
+    std::variant<GenerateResponse, ChatResponse> response;
+    bool done{false};
+    QString doneReason;
+    Metrics metrics;
+    QString error;
+
+    static OllamaMessage fromJson(const QJsonObject &obj, Type type);
+
+    QString getContent() const;
+
+    bool hasError() const;
+};
+
+} // namespace QodeAssist::Providers
diff --git a/providers/OllamaProvider.cpp b/providers/OllamaProvider.cpp
index b852d61..180a000 100644
--- a/providers/OllamaProvider.cpp
+++ b/providers/OllamaProvider.cpp
@@ -25,6 +25,7 @@
 #include <QJsonDocument>
 #include <QNetworkReply>
 
+#include "OllamaMessage.hpp"
 #include "logger/Logger.hpp"
 #include "settings/ChatAssistantSettings.hpp"
 #include "settings/CodeCompletionSettings.hpp"
@@ -87,53 +88,41 @@ void OllamaProvider::prepareRequest(QJsonObject &request, LLMCore::RequestType t
 bool OllamaProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
 {
-    QString endpoint = reply->url().path();
+    const QString endpoint = reply->url().path();
+    auto messageType = endpoint == completionEndpoint() ? OllamaMessage::Type::Generate
+                                                        : OllamaMessage::Type::Chat;
 
-    bool isComplete = false;
-    while (reply->canReadLine()) {
-        QByteArray line = reply->readLine().trimmed();
-        if (line.isEmpty()) {
-            continue;
-        }
-
-        QJsonDocument doc = QJsonDocument::fromJson(line);
-        if (doc.isNull()) {
-            LOG_MESSAGE("Invalid JSON response from Ollama: " + QString::fromUtf8(line));
-            continue;
-        }
-
-        QJsonObject responseObj = doc.object();
-
-        if (responseObj.contains("error")) {
-            QString errorMessage = responseObj["error"].toString();
-            LOG_MESSAGE("Error in Ollama response: " + errorMessage);
-            return false;
-        }
-
-        if (endpoint == completionEndpoint()) {
-            if (responseObj.contains("response")) {
-                QString completion = responseObj["response"].toString();
-                accumulatedResponse += completion;
+    auto processMessage =
+        [&accumulatedResponse](const QJsonDocument &doc, OllamaMessage::Type messageType) {
+            if (doc.isNull()) {
+                LOG_MESSAGE("Invalid JSON response from Ollama");
+                return false;
             }
-        } else if (endpoint == chatEndpoint()) {
-            if (responseObj.contains("message")) {
-                QJsonObject message = responseObj["message"].toObject();
-                if (message.contains("content")) {
-                    QString content = message["content"].toString();
-                    accumulatedResponse += content;
-                }
-            }
-        } else {
-            LOG_MESSAGE("Unknown endpoint: " + endpoint);
-        }
 
-        if (responseObj.contains("done") && responseObj["done"].toBool()) {
-            isComplete = true;
-            break;
+            auto message = OllamaMessage::fromJson(doc.object(), messageType);
+            if (message.hasError()) {
+                LOG_MESSAGE("Error in Ollama response: " + message.error);
+                return false;
+            }
+
+            accumulatedResponse += message.getContent();
+            return message.done;
+        };
+
+    if (reply->canReadLine()) {
+        while (reply->canReadLine()) {
+            QByteArray line = reply->readLine().trimmed();
+            if (line.isEmpty())
+                continue;
+
+            if (processMessage(QJsonDocument::fromJson(line), messageType)) {
+                return true;
+            }
         }
+        return false;
+    } else {
+        return processMessage(QJsonDocument::fromJson(reply->readAll()), messageType);
     }
-
-    return isComplete;
 }
 
 QList<QString> OllamaProvider::getInstalledModels(const QString &url)