From 545b8ed000d098d9aca912f4247803fd842c98dd Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Mon, 30 Mar 2026 08:08:49 +0200 Subject: [PATCH] refactor: Adapt provider to clients API --- CMakeLists.txt | 5 - pluginllmcore/Provider.cpp | 29 +- pluginllmcore/Provider.hpp | 25 +- providers/ClaudeProvider.cpp | 429 +++++------------------ providers/ClaudeProvider.hpp | 28 +- providers/GoogleAIProvider.cpp | 422 ++++------------------ providers/GoogleAIProvider.hpp | 32 +- providers/LMStudioProvider.cpp | 314 ++++------------- providers/LMStudioProvider.hpp | 26 +- providers/LlamaCppProvider.cpp | 307 ++++------------ providers/LlamaCppProvider.hpp | 29 +- providers/MistralAIProvider.cpp | 322 ++++------------- providers/MistralAIProvider.hpp | 26 +- providers/OllamaProvider.cpp | 439 ++++------------------- providers/OllamaProvider.hpp | 29 +- providers/OpenAICompatProvider.cpp | 296 ++++------------ providers/OpenAICompatProvider.hpp | 26 +- providers/OpenAIProvider.cpp | 330 ++++------------- providers/OpenAIProvider.hpp | 24 +- providers/OpenAIResponsesProvider.cpp | 487 ++++---------------------- providers/OpenAIResponsesProvider.hpp | 30 +- 21 files changed, 697 insertions(+), 2958 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f4bfde9..7168c53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -114,7 +114,6 @@ add_qtc_plugin(QodeAssist providers/OpenAIResponses/ItemTypesReference.hpp providers/OpenAIResponsesRequestBuilder.hpp providers/OpenAIResponsesProvider.hpp providers/OpenAIResponsesProvider.cpp - providers/OpenAIResponsesMessage.hpp providers/OpenAIResponsesMessage.cpp QodeAssist.qrc LSPCompletion.hpp LLMSuggestion.hpp LLMSuggestion.cpp @@ -154,10 +153,6 @@ add_qtc_plugin(QodeAssist tools/FindAndReadFileTool.hpp tools/FindAndReadFileTool.cpp tools/FileSearchUtils.hpp tools/FileSearchUtils.cpp tools/TodoTool.hpp tools/TodoTool.cpp - providers/ClaudeMessage.hpp providers/ClaudeMessage.cpp - providers/OpenAIMessage.hpp providers/OpenAIMessage.cpp - providers/OllamaMessage.hpp providers/OllamaMessage.cpp - providers/GoogleMessage.hpp providers/GoogleMessage.cpp ) get_target_property(QtCreatorCorePath QtCreator::Core LOCATION) diff --git a/pluginllmcore/Provider.cpp b/pluginllmcore/Provider.cpp index 0a9da24..e5274f6 100644 --- a/pluginllmcore/Provider.cpp +++ b/pluginllmcore/Provider.cpp @@ -1,36 +1,9 @@ #include "Provider.hpp" -#include - namespace QodeAssist::PluginLLMCore { Provider::Provider(QObject *parent) : QObject(parent) - , m_httpClient(new HttpClient(this)) -{ - connect(m_httpClient, &HttpClient::dataReceived, this, &Provider::onDataReceived); - connect(m_httpClient, &HttpClient::requestFinished, this, &Provider::onRequestFinished); -} - -void Provider::cancelRequest(const RequestID &requestId) -{ - m_httpClient->cancelRequest(requestId); -} - -HttpClient *Provider::httpClient() const -{ - return m_httpClient; -} - -QJsonObject Provider::parseEventLine(const QString &line) -{ - if (!line.startsWith("data: ")) - return QJsonObject(); - - QString jsonStr = line.mid(6); - - QJsonDocument doc = QJsonDocument::fromJson(jsonStr.toUtf8()); - return doc.object(); -} +{} } // namespace QodeAssist::PluginLLMCore diff --git a/pluginllmcore/Provider.hpp b/pluginllmcore/Provider.hpp index 7154563..2303754 100644 --- a/pluginllmcore/Provider.hpp +++ b/pluginllmcore/Provider.hpp @@ -19,8 +19,6 @@ #pragma once -#include - #include #include #include @@ -28,8 +26,6 @@ #include #include "ContextData.hpp" -#include "DataBuffers.hpp" -#include "HttpClient.hpp" #include "IToolsManager.hpp" #include "PromptTemplate.hpp" #include "RequestType.hpp" @@ -38,7 +34,6 @@ namespace LLMCore { class ToolsManager; } -class QNetworkReply; class QJsonObject; namespace QodeAssist::PluginLLMCore { @@ -77,20 +72,10 @@ public: virtual bool supportThinking() const { return false; }; virtual bool supportImage() const { return false; }; - virtual void cancelRequest(const RequestID &requestId); + virtual void cancelRequest(const RequestID &requestId) = 0; virtual ::LLMCore::ToolsManager *toolsManager() const { return nullptr; } - HttpClient *httpClient() const; - -public slots: - virtual void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) - = 0; - virtual void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) - = 0; - signals: void partialResponseReceived( const QodeAssist::PluginLLMCore::RequestID &requestId, const QString &partialText); @@ -109,14 +94,6 @@ signals: const QString &requestId, const QString &thinking, const QString &signature); void redactedThinkingBlockReceived(const QString &requestId, const QString &signature); -protected: - QJsonObject parseEventLine(const QString &line); - - QHash m_dataBuffers; - QHash m_requestUrls; - -private: - HttpClient *m_httpClient; }; } // namespace QodeAssist::PluginLLMCore diff --git a/providers/ClaudeProvider.cpp b/providers/ClaudeProvider.cpp index 67069da..c2a7dcf 100644 --- a/providers/ClaudeProvider.cpp +++ b/providers/ClaudeProvider.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -22,7 +22,6 @@ #include #include #include -#include #include @@ -39,15 +38,10 @@ namespace QodeAssist::Providers { ClaudeProvider::ClaudeProvider(QObject *parent) : PluginLLMCore::Provider(parent) - , m_client(new ::LLMCore::ClaudeClient(url(), apiKey(), QString(), this)) + , m_client(new ::LLMCore::ClaudeClient( + url(), Settings::providerSettings().claudeApiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &ClaudeProvider::onToolExecutionComplete); } QString ClaudeProvider::name() const @@ -131,11 +125,6 @@ void ClaudeProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { @@ -147,40 +136,13 @@ void ClaudeProvider::prepareRequest( QFuture> ClaudeProvider::getInstalledModels(const QString &baseUrl) { - QUrl url(baseUrl + "/v1/models"); - QUrlQuery query; - query.addQueryItem("limit", "1000"); - url.setQuery(query); - - QNetworkRequest request(url); - request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); - request.setRawHeader("anthropic-version", "2023-06-01"); - - if (!apiKey().isEmpty()) { - request.setRawHeader("x-api-key", apiKey().toUtf8()); - } - - return httpClient()->get(request).then([](const QByteArray &data) { - QList models; - QJsonObject jsonObject = QJsonDocument::fromJson(data).object(); - - if (jsonObject.contains("data")) { - QJsonArray modelArray = jsonObject["data"].toArray(); - for (const QJsonValue &value : modelArray) { - QJsonObject modelObject = value.toObject(); - if (modelObject.contains("id")) { - models.append(modelObject["id"].toString()); - } - } - } - return models; - }).onFailed([](const std::exception &e) { - LOG_MESSAGE(QString("Error fetching Claude models: %1").arg(e.what())); - return QList{}; - }); + m_client->setUrl(baseUrl); + m_client->setApiKey(apiKey()); + return m_client->listModels(); } -QList ClaudeProvider::validateRequest(const QJsonObject &request, PluginLLMCore::TemplateType type) +QList ClaudeProvider::validateRequest( + const QJsonObject &request, PluginLLMCore::TemplateType type) { const auto templateReq = QJsonObject{ {"model", {}}, @@ -222,19 +184,69 @@ PluginLLMCore::ProviderID ClaudeProvider::providerID() const void ClaudeProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE(QString("ClaudeProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("ClaudeProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool ClaudeProvider::supportsTools() const @@ -242,19 +254,26 @@ bool ClaudeProvider::supportsTools() const return true; } -bool ClaudeProvider::supportThinking() const { +bool ClaudeProvider::supportThinking() const +{ return true; -}; +} -bool ClaudeProvider::supportImage() const { +bool ClaudeProvider::supportImage() const +{ return true; -}; +} void ClaudeProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("ClaudeProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); + + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); + } + m_awaitingContinuation.remove(requestId); } ::LLMCore::ToolsManager *ClaudeProvider::toolsManager() const @@ -262,292 +281,4 @@ void ClaudeProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) return m_client->tools(); } -void ClaudeProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - for (const QString &line : lines) { - QJsonObject responseObj = parseEventLine(line); - if (responseObj.isEmpty()) - continue; - - processStreamEvent(requestId, responseObj); - } -} - -void ClaudeProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("ClaudeProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - ClaudeMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - m_dataBuffers.remove(requestId); - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); - emit fullResponseReceived(requestId, buffers.responseContent); - } - } - - cleanupRequest(requestId); -} - -void ClaudeProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - LOG_MESSAGE(QString("Tool execution complete for Claude request %1").arg(requestId)); - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - ClaudeMessage *message = m_messages[requestId]; - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - ClaudeMessage *message = m_messages[requestId]; - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray messages = continuationRequest["messages"].toArray(); - - messages.append(message->toProviderFormat()); - - QJsonObject userMessage; - userMessage["role"] = "user"; - userMessage["content"] = message->createToolResultsContent(toolResults); - messages.append(userMessage); - - continuationRequest["messages"] = messages; - - if (continuationRequest.contains("thinking")) { - QJsonObject thinkingObj = continuationRequest["thinking"].toObject(); - LOG_MESSAGE(QString("Thinking mode preserved for continuation: type=%1, budget=%2 tokens") - .arg(thinkingObj["type"].toString()) - .arg(thinkingObj["budget_tokens"].toInt())); - } - - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void ClaudeProvider::processStreamEvent(const QString &requestId, const QJsonObject &event) -{ - QString eventType = event["type"].toString(); - - if (eventType == "message_stop") { - return; - } - - ClaudeMessage *message = m_messages.value(requestId); - if (!message) { - if (eventType == "message_start") { - message = new ClaudeMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW ClaudeMessage for request %1").arg(requestId)); - } else { - return; - } - } - - if (eventType == "message_start") { - message->startNewContinuation(); - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting NEW continuation for request %1").arg(requestId)); - - } else if (eventType == "content_block_start") { - int index = event["index"].toInt(); - QJsonObject contentBlock = event["content_block"].toObject(); - QString blockType = contentBlock["type"].toString(); - - LOG_MESSAGE( - QString("Adding new content block: type=%1, index=%2").arg(blockType).arg(index)); - - if (blockType == "thinking" || blockType == "redacted_thinking") { - QJsonDocument eventDoc(event); - LOG_MESSAGE(QString("content_block_start event for %1: %2") - .arg(blockType) - .arg(QString::fromUtf8(eventDoc.toJson(QJsonDocument::Compact)))); - } - - message->handleContentBlockStart(index, blockType, contentBlock); - - } else if (eventType == "content_block_delta") { - int index = event["index"].toInt(); - QJsonObject delta = event["delta"].toObject(); - QString deltaType = delta["type"].toString(); - - message->handleContentBlockDelta(index, deltaType, delta); - - if (deltaType == "text_delta") { - QString text = delta["text"].toString(); - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += text; - emit partialResponseReceived(requestId, text); - } else if (deltaType == "signature_delta") { - QString signature = delta["signature"].toString(); - } - - } else if (eventType == "content_block_stop") { - int index = event["index"].toInt(); - - auto allBlocks = message->getCurrentBlocks(); - if (index < allBlocks.size()) { - QString blockType = allBlocks[index]->type(); - if (blockType == "thinking" || blockType == "redacted_thinking") { - QJsonDocument eventDoc(event); - LOG_MESSAGE(QString("content_block_stop event for %1 at index %2: %3") - .arg(blockType) - .arg(index) - .arg(QString::fromUtf8(eventDoc.toJson(QJsonDocument::Compact)))); - } - } - - if (event.contains("content_block")) { - QJsonObject contentBlock = event["content_block"].toObject(); - QString blockType = contentBlock["type"].toString(); - - if (blockType == "thinking") { - QString signature = contentBlock["signature"].toString(); - if (!signature.isEmpty()) { - auto allBlocks = message->getCurrentBlocks(); - if (index < allBlocks.size()) { - if (auto thinkingContent = dynamic_cast<::LLMCore::ThinkingContent *>(allBlocks[index])) { - thinkingContent->setSignature(signature); - LOG_MESSAGE( - QString("Updated thinking block signature from content_block_stop, " - "signature length=%1") - .arg(signature.length())); - } - } - } - } else if (blockType == "redacted_thinking") { - QString signature = contentBlock["signature"].toString(); - if (!signature.isEmpty()) { - auto allBlocks = message->getCurrentBlocks(); - if (index < allBlocks.size()) { - if (auto redactedContent = dynamic_cast<::LLMCore::RedactedThinkingContent *>(allBlocks[index])) { - redactedContent->setSignature(signature); - LOG_MESSAGE( - QString("Updated redacted_thinking block signature from content_block_stop, " - "signature length=%1") - .arg(signature.length())); - } - } - } - } - } - - message->handleContentBlockStop(index); - - auto thinkingBlocks = message->getCurrentThinkingContent(); - for (auto thinkingContent : thinkingBlocks) { - auto allBlocks = message->getCurrentBlocks(); - if (index < allBlocks.size() && allBlocks[index] == thinkingContent) { - emit thinkingBlockReceived( - requestId, thinkingContent->thinking(), thinkingContent->signature()); - LOG_MESSAGE( - QString("Emitted thinking block for request %1, thinking length=%2, signature length=%3") - .arg(requestId) - .arg(thinkingContent->thinking().length()) - .arg(thinkingContent->signature().length())); - break; - } - } - - auto redactedBlocks = message->getCurrentRedactedThinkingContent(); - for (auto redactedContent : redactedBlocks) { - auto allBlocks = message->getCurrentBlocks(); - if (index < allBlocks.size() && allBlocks[index] == redactedContent) { - emit redactedThinkingBlockReceived(requestId, redactedContent->signature()); - LOG_MESSAGE( - QString("Emitted redacted thinking block for request %1, signature length=%2") - .arg(requestId) - .arg(redactedContent->signature().length())); - break; - } - } - - } else if (eventType == "message_delta") { - QJsonObject delta = event["delta"].toObject(); - if (delta.contains("stop_reason")) { - QString stopReason = delta["stop_reason"].toString(); - message->handleStopReason(stopReason); - handleMessageComplete(requestId); - } - } -} - -void ClaudeProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - ClaudeMessage *message = m_messages[requestId]; - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Claude message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("Claude message marked as complete for %1").arg(requestId)); - } -} - -void ClaudeProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up Claude request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - ClaudeMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_client->tools()->cleanupRequest(requestId); -} - } // namespace QodeAssist::Providers diff --git a/providers/ClaudeProvider.hpp b/providers/ClaudeProvider.hpp index e233aec..a986386 100644 --- a/providers/ClaudeProvider.hpp +++ b/providers/ClaudeProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,9 +19,10 @@ #pragma once +#include + #include -#include "ClaudeMessage.hpp" #include namespace QodeAssist::Providers { @@ -57,29 +58,14 @@ public: bool supportThinking() const override; bool supportImage() const override; void cancelRequest(const PluginLLMCore::RequestID &requestId) override; - + ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamEvent(const QString &requestId, const QJsonObject &event); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; ::LLMCore::ClaudeClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/GoogleAIProvider.cpp b/providers/GoogleAIProvider.cpp index 1a1f12a..14acb6f 100644 --- a/providers/GoogleAIProvider.cpp +++ b/providers/GoogleAIProvider.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -42,12 +42,6 @@ GoogleAIProvider::GoogleAIProvider(QObject *parent) , m_client(new ::LLMCore::GoogleAIClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &GoogleAIProvider::onToolExecutionComplete); } QString GoogleAIProvider::name() const @@ -145,11 +139,6 @@ void GoogleAIProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { request["tools"] = toolsDefinitions; @@ -158,33 +147,11 @@ void GoogleAIProvider::prepareRequest( } } -QFuture> GoogleAIProvider::getInstalledModels(const QString &url) +QFuture> GoogleAIProvider::getInstalledModels(const QString &baseUrl) { - QNetworkRequest request(QString("%1/models?key=%2").arg(url, apiKey())); - request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); - - return httpClient()->get(request).then([](const QByteArray &data) { - QList models; - QJsonObject jsonObject = QJsonDocument::fromJson(data).object(); - - if (jsonObject.contains("models")) { - QJsonArray modelArray = jsonObject["models"].toArray(); - for (const QJsonValue &value : modelArray) { - QJsonObject modelObject = value.toObject(); - if (modelObject.contains("name")) { - QString modelName = modelObject["name"].toString(); - if (modelName.contains("/")) { - modelName = modelName.split("/").last(); - } - models.append(modelName); - } - } - } - return models; - }).onFailed([](const std::exception &e) { - LOG_MESSAGE(QString("Error fetching Google AI models: %1").arg(e.what())); - return QList{}; - }); + m_client->setUrl(baseUrl); + m_client->setApiKey(apiKey()); + return m_client->listModels(); } QList GoogleAIProvider::validateRequest( @@ -233,20 +200,69 @@ PluginLLMCore::ProviderID GoogleAIProvider::providerID() const void GoogleAIProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE( - QString("GoogleAIProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("GoogleAIProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool GoogleAIProvider::supportsTools() const @@ -267,313 +283,13 @@ bool GoogleAIProvider::supportImage() const void GoogleAIProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("GoogleAIProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} -void GoogleAIProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - if (data.isEmpty()) { - return; + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } - - QJsonParseError parseError; - QJsonDocument doc = QJsonDocument::fromJson(data, &parseError); - if (!doc.isNull() && doc.isObject()) { - QJsonObject obj = doc.object(); - if (obj.contains("error")) { - QJsonObject error = obj["error"].toObject(); - QString errorMessage = error["message"].toString(); - int errorCode = error["code"].toInt(); - QString fullError - = QString("Google AI API Error %1: %2").arg(errorCode).arg(errorMessage); - - LOG_MESSAGE(fullError); - emit requestFailed(requestId, fullError); - cleanupRequest(requestId); - return; - } - } - - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - for (const QString &line : lines) { - if (line.trimmed().isEmpty()) { - continue; - } - - QJsonObject chunk = parseEventLine(line); - if (chunk.isEmpty()) - continue; - - processStreamChunk(requestId, chunk); - } -} - -void GoogleAIProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("GoogleAIProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_failedRequests.contains(requestId)) { - cleanupRequest(requestId); - return; - } - - emitPendingThinkingBlocks(requestId); - - if (m_messages.contains(requestId)) { - GoogleMessage *message = m_messages[requestId]; - - handleMessageComplete(requestId); - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - m_dataBuffers.remove(requestId); - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - emit fullResponseReceived(requestId, buffers.responseContent); - } else { - emit fullResponseReceived(requestId, QString()); - } - } else { - emit fullResponseReceived(requestId, QString()); - } - - cleanupRequest(requestId); -} - -void GoogleAIProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - GoogleMessage *message = m_messages[requestId]; - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - GoogleMessage *message = m_messages[requestId]; - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray contents = continuationRequest["contents"].toArray(); - - contents.append(message->toProviderFormat()); - - QJsonObject userMessage; - userMessage["role"] = "user"; - userMessage["parts"] = message->createToolResultParts(toolResults); - contents.append(userMessage); - - continuationRequest["contents"] = contents; - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void GoogleAIProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) -{ - if (!chunk.contains("candidates")) { - return; - } - - GoogleMessage *message = m_messages.value(requestId); - if (!message) { - message = new GoogleMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW GoogleMessage for request %1").arg(requestId)); - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting continuation for request %1").arg(requestId)); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - m_emittedThinkingBlocksCount[requestId] = 0; - LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); - } - - QJsonArray candidates = chunk["candidates"].toArray(); - for (const QJsonValue &candidate : candidates) { - QJsonObject candidateObj = candidate.toObject(); - - if (candidateObj.contains("content")) { - QJsonObject content = candidateObj["content"].toObject(); - if (content.contains("parts")) { - QJsonArray parts = content["parts"].toArray(); - for (const QJsonValue &part : parts) { - QJsonObject partObj = part.toObject(); - - if (partObj.contains("text")) { - QString text = partObj["text"].toString(); - bool isThought = partObj.value("thought").toBool(false); - - if (isThought) { - message->handleThoughtDelta(text); - - if (partObj.contains("signature")) { - QString signature = partObj["signature"].toString(); - message->handleThoughtSignature(signature); - } - } else { - emitPendingThinkingBlocks(requestId); - - message->handleContentDelta(text); - - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += text; - emit partialResponseReceived(requestId, text); - } - } - - if (partObj.contains("thoughtSignature")) { - QString signature = partObj["thoughtSignature"].toString(); - message->handleThoughtSignature(signature); - } - - if (partObj.contains("functionCall")) { - emitPendingThinkingBlocks(requestId); - - QJsonObject functionCall = partObj["functionCall"].toObject(); - QString name = functionCall["name"].toString(); - QJsonObject args = functionCall["args"].toObject(); - - message->handleFunctionCallStart(name); - message->handleFunctionCallArgsDelta( - QString::fromUtf8(QJsonDocument(args).toJson(QJsonDocument::Compact))); - message->handleFunctionCallComplete(); - } - } - } - } - - if (candidateObj.contains("finishReason")) { - QString finishReason = candidateObj["finishReason"].toString(); - message->handleFinishReason(finishReason); - - if (message->isErrorFinishReason()) { - QString errorMessage = message->getErrorMessage(); - LOG_MESSAGE(QString("Google AI error: %1").arg(errorMessage)); - m_failedRequests.insert(requestId); - emit requestFailed(requestId, errorMessage); - return; - } - } - } - - if (chunk.contains("usageMetadata")) { - QJsonObject usageMetadata = chunk["usageMetadata"].toObject(); - int thoughtsTokenCount = usageMetadata.value("thoughtsTokenCount").toInt(0); - int candidatesTokenCount = usageMetadata.value("candidatesTokenCount").toInt(0); - int totalTokenCount = usageMetadata.value("totalTokenCount").toInt(0); - - if (totalTokenCount > 0) { - LOG_MESSAGE(QString("Google AI tokens: %1 (thoughts: %2, output: %3)") - .arg(totalTokenCount) - .arg(thoughtsTokenCount) - .arg(candidatesTokenCount)); - } - } -} - -void GoogleAIProvider::emitPendingThinkingBlocks(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - GoogleMessage *message = m_messages[requestId]; - auto thinkingBlocks = message->getCurrentThinkingContent(); - - if (thinkingBlocks.isEmpty()) - return; - - int alreadyEmitted = m_emittedThinkingBlocksCount.value(requestId, 0); - int totalBlocks = thinkingBlocks.size(); - - for (int i = alreadyEmitted; i < totalBlocks; ++i) { - auto thinkingContent = thinkingBlocks[i]; - - if (thinkingContent->thinking().trimmed().isEmpty()) { - continue; - } - - emit thinkingBlockReceived( - requestId, - thinkingContent->thinking(), - thinkingContent->signature()); - } - - m_emittedThinkingBlocksCount[requestId] = totalBlocks; -} - -void GoogleAIProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - GoogleMessage *message = m_messages[requestId]; - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Google AI message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("Google AI message marked as complete for %1").arg(requestId)); - } -} - -void GoogleAIProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up Google AI request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - GoogleMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_emittedThinkingBlocksCount.remove(requestId); - m_failedRequests.remove(requestId); - m_client->tools()->cleanupRequest(requestId); + m_awaitingContinuation.remove(requestId); } ::LLMCore::ToolsManager *GoogleAIProvider::toolsManager() const diff --git a/providers/GoogleAIProvider.hpp b/providers/GoogleAIProvider.hpp index d3cdf2b..d3df995 100644 --- a/providers/GoogleAIProvider.hpp +++ b/providers/GoogleAIProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,8 +19,10 @@ #pragma once -#include "GoogleMessage.hpp" -#include "pluginllmcore/Provider.hpp" +#include + +#include + #include namespace QodeAssist::Providers { @@ -59,29 +61,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamChunk(const QString &requestId, const QJsonObject &chunk); - void handleMessageComplete(const QString &requestId); - void emitPendingThinkingBlocks(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; - QHash m_emittedThinkingBlocksCount; - QSet m_failedRequests; ::LLMCore::GoogleAIClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/LMStudioProvider.cpp b/providers/LMStudioProvider.cpp index f72ae25..a3c60f3 100644 --- a/providers/LMStudioProvider.cpp +++ b/providers/LMStudioProvider.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -41,12 +41,6 @@ LMStudioProvider::LMStudioProvider(QObject *parent) , m_client(new ::LLMCore::OpenAIClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &LMStudioProvider::onToolExecutionComplete); } QString LMStudioProvider::name() const @@ -76,22 +70,9 @@ bool LMStudioProvider::supportsModelListing() const QFuture> LMStudioProvider::getInstalledModels(const QString &url) { - QNetworkRequest request(QString("%1%2").arg(url, "/v1/models")); - - return httpClient()->get(request).then([](const QByteArray &data) { - QList models; - QJsonObject jsonObject = QJsonDocument::fromJson(data).object(); - QJsonArray modelArray = jsonObject["data"].toArray(); - - for (const QJsonValue &value : modelArray) { - QJsonObject modelObject = value.toObject(); - models.append(modelObject["id"].toString()); - } - return models; - }).onFailed([](const std::exception &e) { - LOG_MESSAGE(QString("Error fetching LMStudio models: %1").arg(e.what())); - return QList{}; - }); + m_client->setUrl(url); + m_client->setApiKey(apiKey()); + return m_client->listModels(); } QList LMStudioProvider::validateRequest( @@ -131,20 +112,69 @@ PluginLLMCore::ProviderID LMStudioProvider::providerID() const void LMStudioProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE( - QString("LMStudioProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("LMStudioProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool LMStudioProvider::supportsTools() const @@ -160,57 +190,13 @@ bool LMStudioProvider::supportImage() const void LMStudioProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("LMStudioProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} -void LMStudioProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - for (const QString &line : lines) { - if (line.trimmed().isEmpty() || line == "data: [DONE]") { - continue; - } - - QJsonObject chunk = parseEventLine(line); - if (chunk.isEmpty()) - continue; - - processStreamChunk(requestId, chunk); + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } -} - -void LMStudioProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("LMStudioProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - m_dataBuffers.remove(requestId); - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); - emit fullResponseReceived(requestId, buffers.responseContent); - } - } - - cleanupRequest(requestId); + m_awaitingContinuation.remove(requestId); } void LMStudioProvider::prepareRequest( @@ -250,11 +236,6 @@ void LMStudioProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { request["tools"] = toolsDefinitions; @@ -263,165 +244,6 @@ void LMStudioProvider::prepareRequest( } } -void LMStudioProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - LOG_MESSAGE(QString("Tool execution complete for LMStudio request %1").arg(requestId)); - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - OpenAIMessage *message = m_messages[requestId]; - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - OpenAIMessage *message = m_messages[requestId]; - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray messages = continuationRequest["messages"].toArray(); - - messages.append(message->toProviderFormat()); - - QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); - for (const auto &toolMsg : toolResultMessages) { - messages.append(toolMsg); - } - - continuationRequest["messages"] = messages; - - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void LMStudioProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) -{ - QJsonArray choices = chunk["choices"].toArray(); - if (choices.isEmpty()) { - return; - } - - QJsonObject choice = choices[0].toObject(); - QJsonObject delta = choice["delta"].toObject(); - QString finishReason = choice["finish_reason"].toString(); - - OpenAIMessage *message = m_messages.value(requestId); - if (!message) { - message = new OpenAIMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW OpenAIMessage for request %1").arg(requestId)); - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting continuation for request %1").arg(requestId)); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); - } - - if (delta.contains("content") && !delta["content"].isNull()) { - QString content = delta["content"].toString(); - message->handleContentDelta(content); - - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - - if (delta.contains("tool_calls")) { - QJsonArray toolCalls = delta["tool_calls"].toArray(); - for (const auto &toolCallValue : toolCalls) { - QJsonObject toolCall = toolCallValue.toObject(); - int index = toolCall["index"].toInt(); - - if (toolCall.contains("id")) { - QString id = toolCall["id"].toString(); - QJsonObject function = toolCall["function"].toObject(); - QString name = function["name"].toString(); - message->handleToolCallStart(index, id, name); - } - - if (toolCall.contains("function")) { - QJsonObject function = toolCall["function"].toObject(); - if (function.contains("arguments")) { - QString args = function["arguments"].toString(); - message->handleToolCallDelta(index, args); - } - } - } - } - - if (!finishReason.isEmpty() && finishReason != "null") { - for (int i = 0; i < 10; ++i) { - message->handleToolCallComplete(i); - } - - message->handleFinishReason(finishReason); - handleMessageComplete(requestId); - } -} - -void LMStudioProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - OpenAIMessage *message = m_messages[requestId]; - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("LMStudio message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("LMStudio message marked as complete for %1").arg(requestId)); - } -} - -void LMStudioProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up LMStudio request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_client->tools()->cleanupRequest(requestId); -} - ::LLMCore::ToolsManager *LMStudioProvider::toolsManager() const { return m_client->tools(); diff --git a/providers/LMStudioProvider.hpp b/providers/LMStudioProvider.hpp index b32f43c..bbb12e1 100644 --- a/providers/LMStudioProvider.hpp +++ b/providers/LMStudioProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,7 +19,8 @@ #pragma once -#include "OpenAIMessage.hpp" +#include + #include #include @@ -58,26 +59,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamChunk(const QString &requestId, const QJsonObject &chunk); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; ::LLMCore::OpenAIClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/LlamaCppProvider.cpp b/providers/LlamaCppProvider.cpp index 418876b..578a4d0 100644 --- a/providers/LlamaCppProvider.cpp +++ b/providers/LlamaCppProvider.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -39,12 +39,6 @@ LlamaCppProvider::LlamaCppProvider(QObject *parent) , m_client(new ::LLMCore::LlamaCppClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &LlamaCppProvider::onToolExecutionComplete); } QString LlamaCppProvider::name() const @@ -109,11 +103,6 @@ void LlamaCppProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { request["tools"] = toolsDefinitions; @@ -183,20 +172,69 @@ PluginLLMCore::ProviderID LlamaCppProvider::providerID() const void LlamaCppProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE( - QString("LlamaCppProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("LlamaCppProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool LlamaCppProvider::supportsTools() const @@ -212,228 +250,13 @@ bool LlamaCppProvider::supportImage() const void LlamaCppProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("LlamaCppProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} -void LlamaCppProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - for (const QString &line : lines) { - if (line.trimmed().isEmpty() || line == "data: [DONE]") { - continue; - } - - QJsonObject chunk = parseEventLine(line); - if (chunk.isEmpty()) - continue; - - if (chunk.contains("content")) { - QString content = chunk["content"].toString(); - if (!content.isEmpty()) { - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - if (chunk["stop"].toBool()) { - emit fullResponseReceived(requestId, buffers.responseContent); - m_dataBuffers.remove(requestId); - } - } else if (chunk.contains("choices")) { - processStreamChunk(requestId, chunk); - } + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } -} - -void LlamaCppProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("LlamaCppProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - m_dataBuffers.remove(requestId); - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); - emit fullResponseReceived(requestId, buffers.responseContent); - } - } - - cleanupRequest(requestId); -} - -void LlamaCppProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - LOG_MESSAGE(QString("Tool execution complete for llama.cpp request %1").arg(requestId)); - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - OpenAIMessage *message = m_messages[requestId]; - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - OpenAIMessage *message = m_messages[requestId]; - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray messages = continuationRequest["messages"].toArray(); - - messages.append(message->toProviderFormat()); - - QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); - for (const auto &toolMsg : toolResultMessages) { - messages.append(toolMsg); - } - - continuationRequest["messages"] = messages; - - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void LlamaCppProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) -{ - QJsonArray choices = chunk["choices"].toArray(); - if (choices.isEmpty()) { - return; - } - - QJsonObject choice = choices[0].toObject(); - QJsonObject delta = choice["delta"].toObject(); - QString finishReason = choice["finish_reason"].toString(); - - OpenAIMessage *message = m_messages.value(requestId); - if (!message) { - message = new OpenAIMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW OpenAIMessage for llama.cpp request %1").arg(requestId)); - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting continuation for request %1").arg(requestId)); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); - } - - if (delta.contains("content") && !delta["content"].isNull()) { - QString content = delta["content"].toString(); - message->handleContentDelta(content); - - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - - if (delta.contains("tool_calls")) { - QJsonArray toolCalls = delta["tool_calls"].toArray(); - for (const auto &toolCallValue : toolCalls) { - QJsonObject toolCall = toolCallValue.toObject(); - int index = toolCall["index"].toInt(); - - if (toolCall.contains("id")) { - QString id = toolCall["id"].toString(); - QJsonObject function = toolCall["function"].toObject(); - QString name = function["name"].toString(); - message->handleToolCallStart(index, id, name); - } - - if (toolCall.contains("function")) { - QJsonObject function = toolCall["function"].toObject(); - if (function.contains("arguments")) { - QString args = function["arguments"].toString(); - message->handleToolCallDelta(index, args); - } - } - } - } - - if (!finishReason.isEmpty() && finishReason != "null") { - for (int i = 0; i < 10; ++i) { - message->handleToolCallComplete(i); - } - - message->handleFinishReason(finishReason); - handleMessageComplete(requestId); - } -} - -void LlamaCppProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - OpenAIMessage *message = m_messages[requestId]; - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("llama.cpp message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("llama.cpp message marked as complete for %1").arg(requestId)); - } -} - -void LlamaCppProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up llama.cpp request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_client->tools()->cleanupRequest(requestId); + m_awaitingContinuation.remove(requestId); } ::LLMCore::ToolsManager *LlamaCppProvider::toolsManager() const diff --git a/providers/LlamaCppProvider.hpp b/providers/LlamaCppProvider.hpp index 2ae9b4f..166800b 100644 --- a/providers/LlamaCppProvider.hpp +++ b/providers/LlamaCppProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,10 +19,12 @@ #pragma once -#include "OpenAIMessage.hpp" -#include +#include + #include +#include + namespace QodeAssist::Providers { class LlamaCppProvider : public PluginLLMCore::Provider @@ -58,26 +60,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamChunk(const QString &requestId, const QJsonObject &chunk); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; ::LLMCore::LlamaCppClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/MistralAIProvider.cpp b/providers/MistralAIProvider.cpp index c3b333e..3315ed4 100644 --- a/providers/MistralAIProvider.cpp +++ b/providers/MistralAIProvider.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -40,12 +40,6 @@ MistralAIProvider::MistralAIProvider(QObject *parent) , m_client(new ::LLMCore::OpenAIClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &MistralAIProvider::onToolExecutionComplete); } QString MistralAIProvider::name() const @@ -75,30 +69,9 @@ bool MistralAIProvider::supportsModelListing() const QFuture> MistralAIProvider::getInstalledModels(const QString &url) { - QNetworkRequest request(QString("%1/v1/models").arg(url)); - request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); - if (!apiKey().isEmpty()) { - request.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8()); - } - - return httpClient()->get(request).then([](const QByteArray &data) { - QList models; - QJsonObject jsonObject = QJsonDocument::fromJson(data).object(); - - if (jsonObject.contains("data") && jsonObject["object"].toString() == "list") { - QJsonArray modelArray = jsonObject["data"].toArray(); - for (const QJsonValue &value : modelArray) { - QJsonObject modelObject = value.toObject(); - if (modelObject.contains("id")) { - models.append(modelObject["id"].toString()); - } - } - } - return models; - }).onFailed([](const std::exception &e) { - LOG_MESSAGE(QString("Error fetching Mistral AI models: %1").arg(e.what())); - return QList{}; - }); + m_client->setUrl(url); + m_client->setApiKey(apiKey()); + return m_client->listModels(); } QList MistralAIProvider::validateRequest( @@ -151,20 +124,69 @@ PluginLLMCore::ProviderID MistralAIProvider::providerID() const void MistralAIProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE( - QString("MistralAIProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("MistralAIProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool MistralAIProvider::supportsTools() const @@ -180,57 +202,13 @@ bool MistralAIProvider::supportImage() const void MistralAIProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("MistralAIProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} -void MistralAIProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - for (const QString &line : lines) { - if (line.trimmed().isEmpty() || line == "data: [DONE]") { - continue; - } - - QJsonObject chunk = parseEventLine(line); - if (chunk.isEmpty()) - continue; - - processStreamChunk(requestId, chunk); + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } -} - -void MistralAIProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("MistralAIProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - m_dataBuffers.remove(requestId); - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); - emit fullResponseReceived(requestId, buffers.responseContent); - } - } - - cleanupRequest(requestId); + m_awaitingContinuation.remove(requestId); } void MistralAIProvider::prepareRequest( @@ -270,11 +248,6 @@ void MistralAIProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { request["tools"] = toolsDefinitions; @@ -283,165 +256,6 @@ void MistralAIProvider::prepareRequest( } } -void MistralAIProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - LOG_MESSAGE(QString("Tool execution complete for Mistral request %1").arg(requestId)); - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - OpenAIMessage *message = m_messages[requestId]; - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - OpenAIMessage *message = m_messages[requestId]; - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray messages = continuationRequest["messages"].toArray(); - - messages.append(message->toProviderFormat()); - - QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); - for (const auto &toolMsg : toolResultMessages) { - messages.append(toolMsg); - } - - continuationRequest["messages"] = messages; - - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void MistralAIProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) -{ - QJsonArray choices = chunk["choices"].toArray(); - if (choices.isEmpty()) { - return; - } - - QJsonObject choice = choices[0].toObject(); - QJsonObject delta = choice["delta"].toObject(); - QString finishReason = choice["finish_reason"].toString(); - - OpenAIMessage *message = m_messages.value(requestId); - if (!message) { - message = new OpenAIMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW OpenAIMessage for Mistral request %1").arg(requestId)); - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting continuation for request %1").arg(requestId)); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); - } - - if (delta.contains("content") && !delta["content"].isNull()) { - QString content = delta["content"].toString(); - message->handleContentDelta(content); - - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - - if (delta.contains("tool_calls")) { - QJsonArray toolCalls = delta["tool_calls"].toArray(); - for (const auto &toolCallValue : toolCalls) { - QJsonObject toolCall = toolCallValue.toObject(); - int index = toolCall["index"].toInt(); - - if (toolCall.contains("id")) { - QString id = toolCall["id"].toString(); - QJsonObject function = toolCall["function"].toObject(); - QString name = function["name"].toString(); - message->handleToolCallStart(index, id, name); - } - - if (toolCall.contains("function")) { - QJsonObject function = toolCall["function"].toObject(); - if (function.contains("arguments")) { - QString args = function["arguments"].toString(); - message->handleToolCallDelta(index, args); - } - } - } - } - - if (!finishReason.isEmpty() && finishReason != "null") { - for (int i = 0; i < 10; ++i) { - message->handleToolCallComplete(i); - } - - message->handleFinishReason(finishReason); - handleMessageComplete(requestId); - } -} - -void MistralAIProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - OpenAIMessage *message = m_messages[requestId]; - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Mistral message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("Mistral message marked as complete for %1").arg(requestId)); - } -} - -void MistralAIProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up Mistral request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_client->tools()->cleanupRequest(requestId); -} - ::LLMCore::ToolsManager *MistralAIProvider::toolsManager() const { return m_client->tools(); diff --git a/providers/MistralAIProvider.hpp b/providers/MistralAIProvider.hpp index 7d4b1c7..dd954e3 100644 --- a/providers/MistralAIProvider.hpp +++ b/providers/MistralAIProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,7 +19,8 @@ #pragma once -#include "OpenAIMessage.hpp" +#include + #include #include @@ -58,26 +59,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamChunk(const QString &requestId, const QJsonObject &chunk); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; ::LLMCore::OpenAIClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/OllamaProvider.cpp b/providers/OllamaProvider.cpp index 2b48792..683b629 100644 --- a/providers/OllamaProvider.cpp +++ b/providers/OllamaProvider.cpp @@ -22,7 +22,6 @@ #include #include -#include "tools/ToolsRegistration.hpp" #include #include @@ -33,6 +32,7 @@ #include "settings/QuickRefactorSettings.hpp" #include "settings/GeneralSettings.hpp" #include "settings/ProviderSettings.hpp" +#include "tools/ToolsRegistration.hpp" namespace QodeAssist::Providers { @@ -41,12 +41,6 @@ OllamaProvider::OllamaProvider(QObject *parent) , m_client(new ::LLMCore::OllamaClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &OllamaProvider::onToolExecutionComplete); } QString OllamaProvider::name() const @@ -119,7 +113,7 @@ void OllamaProvider::prepareRequest( } else if (type == PluginLLMCore::RequestType::QuickRefactoring) { const auto &qrSettings = Settings::quickRefactorSettings(); applySettings(qrSettings); - + if (isThinkingEnabled) { applyThinkingMode(); LOG_MESSAGE(QString("OllamaProvider: Thinking mode enabled for QuickRefactoring")); @@ -135,11 +129,6 @@ void OllamaProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { request["tools"] = toolsDefinitions; @@ -149,25 +138,11 @@ void OllamaProvider::prepareRequest( } } -QFuture> OllamaProvider::getInstalledModels(const QString &url) +QFuture> OllamaProvider::getInstalledModels(const QString &baseUrl) { - QNetworkRequest request(QString("%1%2").arg(url, "/api/tags")); - prepareNetworkRequest(request); - - return httpClient()->get(request).then([](const QByteArray &data) { - QList models; - QJsonObject jsonObject = QJsonDocument::fromJson(data).object(); - QJsonArray modelArray = jsonObject["models"].toArray(); - - for (const QJsonValue &value : modelArray) { - QJsonObject modelObject = value.toObject(); - models.append(modelObject["name"].toString()); - } - return models; - }).onFailed([](const std::exception &e) { - LOG_MESSAGE(QString("Error fetching models: %1").arg(e.what())); - return QList{}; - }); + m_client->setUrl(baseUrl); + m_client->setApiKey(Settings::providerSettings().ollamaBasicAuthApiKey()); + return m_client->listModels(); } QList OllamaProvider::validateRequest(const QJsonObject &request, PluginLLMCore::TemplateType type) @@ -232,17 +207,69 @@ PluginLLMCore::ProviderID OllamaProvider::providerID() const void OllamaProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - m_dataBuffers[requestId].clear(); + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE(QString("OllamaProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("OllamaProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool OllamaProvider::supportsTools() const @@ -263,339 +290,13 @@ bool OllamaProvider::supportThinking() const void OllamaProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("OllamaProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} -void OllamaProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - if (data.isEmpty()) { - return; + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } - - for (const QString &line : lines) { - if (line.trimmed().isEmpty()) { - continue; - } - - QJsonParseError error; - QJsonDocument doc = QJsonDocument::fromJson(line.toUtf8(), &error); - if (doc.isNull()) { - LOG_MESSAGE(QString("Failed to parse JSON: %1").arg(error.errorString())); - continue; - } - - QJsonObject obj = doc.object(); - - if (obj.contains("error") && !obj["error"].toString().isEmpty()) { - LOG_MESSAGE("Error in Ollama response: " + obj["error"].toString()); - continue; - } - - processStreamData(requestId, obj); - } -} - -void OllamaProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("OllamaProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - OllamaMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - return; - } - } - - QString finalText; - if (m_messages.contains(requestId)) { - OllamaMessage *message = m_messages[requestId]; - - for (auto block : message->currentBlocks()) { - if (auto textContent = qobject_cast(block)) { - finalText += textContent->text(); - } - } - - if (!finalText.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1, length=%2") - .arg(requestId) - .arg(finalText.length())); - emit fullResponseReceived(requestId, finalText); - } - } - - cleanupRequest(requestId); -} - -void OllamaProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: No message found for request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - if (!m_requestUrls.contains(requestId) || !m_originalRequests.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - LOG_MESSAGE(QString("Tool execution complete for Ollama request %1").arg(requestId)); - - OllamaMessage *message = m_messages[requestId]; - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted(requestId, tool->id(), toolStringName, it.value()); - break; - } - } - } - - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray messages = continuationRequest["messages"].toArray(); - - QJsonObject assistantMessage = message->toProviderFormat(); - messages.append(assistantMessage); - - LOG_MESSAGE(QString("Assistant message with tool_calls:\n%1") - .arg( - QString::fromUtf8( - QJsonDocument(assistantMessage).toJson(QJsonDocument::Indented)))); - - QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); - for (const auto &toolMsg : toolResultMessages) { - messages.append(toolMsg); - LOG_MESSAGE(QString("Tool result message:\n%1") - .arg( - QString::fromUtf8( - QJsonDocument(toolMsg.toObject()).toJson(QJsonDocument::Indented)))); - } - - continuationRequest["messages"] = messages; - - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void OllamaProvider::processStreamData(const QString &requestId, const QJsonObject &data) -{ - OllamaMessage *message = m_messages.value(requestId); - if (!message) { - message = new OllamaMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW OllamaMessage for request %1").arg(requestId)); - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting continuation for request %1").arg(requestId)); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); - } - - if (data.contains("thinking")) { - QString thinkingDelta = data["thinking"].toString(); - if (!thinkingDelta.isEmpty()) { - message->handleThinkingDelta(thinkingDelta); - LOG_MESSAGE(QString("OllamaProvider: Received thinking delta, length=%1") - .arg(thinkingDelta.length())); - } - } - - if (data.contains("message")) { - QJsonObject messageObj = data["message"].toObject(); - - if (messageObj.contains("thinking")) { - QString thinkingDelta = messageObj["thinking"].toString(); - if (!thinkingDelta.isEmpty()) { - message->handleThinkingDelta(thinkingDelta); - - if (!m_thinkingStarted.contains(requestId)) { - auto thinkingBlocks = message->getCurrentThinkingContent(); - if (!thinkingBlocks.isEmpty() && thinkingBlocks.first()) { - QString currentThinking = thinkingBlocks.first()->thinking(); - QString displayThinking = currentThinking.length() > 50 - ? QString("%1...").arg(currentThinking.left(50)) - : currentThinking; - - emit thinkingBlockReceived(requestId, displayThinking, ""); - m_thinkingStarted.insert(requestId); - } - } - } - } - - if (messageObj.contains("content")) { - QString content = messageObj["content"].toString(); - if (!content.isEmpty()) { - emitThinkingBlocks(requestId, message); - - message->handleContentDelta(content); - - bool hasTextContent = false; - for (auto block : message->currentBlocks()) { - if (qobject_cast(block)) { - hasTextContent = true; - break; - } - } - - if (hasTextContent) { - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - } - } - - if (messageObj.contains("tool_calls")) { - QJsonArray toolCalls = messageObj["tool_calls"].toArray(); - LOG_MESSAGE( - QString("OllamaProvider: Found %1 structured tool calls").arg(toolCalls.size())); - for (const auto &toolCallValue : toolCalls) { - message->handleToolCall(toolCallValue.toObject()); - } - } - } - else if (data.contains("response")) { - QString content = data["response"].toString(); - if (!content.isEmpty()) { - message->handleContentDelta(content); - - bool hasTextContent = false; - for (auto block : message->currentBlocks()) { - if (qobject_cast(block)) { - hasTextContent = true; - break; - } - } - - if (hasTextContent) { - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - } - } - - if (data["done"].toBool()) { - if (data.contains("signature")) { - QString signature = data["signature"].toString(); - message->handleThinkingComplete(signature); - LOG_MESSAGE(QString("OllamaProvider: Set thinking signature, length=%1") - .arg(signature.length())); - } - - message->handleDone(true); - handleMessageComplete(requestId); - } -} - -void OllamaProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - OllamaMessage *message = m_messages[requestId]; - - emitThinkingBlocks(requestId, message); - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Ollama message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE( - QString("WARNING: No tools to execute for %1 despite RequiresToolExecution state") - .arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - - LOG_MESSAGE( - QString("Executing tool: name=%1, id=%2, input=%3") - .arg(toolContent->name()) - .arg(toolContent->id()) - .arg( - QString::fromUtf8( - QJsonDocument(toolContent->input()).toJson(QJsonDocument::Compact)))); - - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("Ollama message marked as complete for %1").arg(requestId)); - } -} - -void OllamaProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up Ollama request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - auto msg = m_messages.take(requestId); - msg->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_thinkingEmitted.remove(requestId); - m_thinkingStarted.remove(requestId); - m_client->tools()->cleanupRequest(requestId); -} - -void OllamaProvider::emitThinkingBlocks(const QString &requestId, OllamaMessage *message) -{ - if (!message || m_thinkingEmitted.contains(requestId)) { - return; - } - - auto thinkingBlocks = message->getCurrentThinkingContent(); - if (thinkingBlocks.isEmpty()) { - return; - } - - for (auto thinkingContent : thinkingBlocks) { - emit thinkingBlockReceived( - requestId, thinkingContent->thinking(), thinkingContent->signature()); - LOG_MESSAGE(QString("Emitted thinking block for request %1, thinking length=%2, signature " - "length=%3") - .arg(requestId) - .arg(thinkingContent->thinking().length()) - .arg(thinkingContent->signature().length())); - } - m_thinkingEmitted.insert(requestId); + m_awaitingContinuation.remove(requestId); } ::LLMCore::ToolsManager *OllamaProvider::toolsManager() const diff --git a/providers/OllamaProvider.hpp b/providers/OllamaProvider.hpp index 3459553..bcf9e22 100644 --- a/providers/OllamaProvider.hpp +++ b/providers/OllamaProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,9 +19,10 @@ #pragma once +#include + #include -#include "OllamaMessage.hpp" #include namespace QodeAssist::Providers { @@ -60,29 +61,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamData(const QString &requestId, const QJsonObject &data); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - void emitThinkingBlocks(const QString &requestId, OllamaMessage *message); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; - QSet m_thinkingEmitted; - QSet m_thinkingStarted; ::LLMCore::OllamaClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/OpenAICompatProvider.cpp b/providers/OpenAICompatProvider.cpp index e74835e..caa4a1c 100644 --- a/providers/OpenAICompatProvider.cpp +++ b/providers/OpenAICompatProvider.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -32,7 +32,6 @@ #include #include #include -#include namespace QodeAssist::Providers { @@ -41,12 +40,6 @@ OpenAICompatProvider::OpenAICompatProvider(QObject *parent) , m_client(new ::LLMCore::OpenAIClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &OpenAICompatProvider::onToolExecutionComplete); } QString OpenAICompatProvider::name() const @@ -111,11 +104,6 @@ void OpenAICompatProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { request["tools"] = toolsDefinitions; @@ -171,20 +159,69 @@ PluginLLMCore::ProviderID OpenAICompatProvider::providerID() const void OpenAICompatProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE( - QString("OpenAICompatProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("OpenAICompatProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool OpenAICompatProvider::supportsTools() const @@ -200,216 +237,13 @@ bool OpenAICompatProvider::supportImage() const void OpenAICompatProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("OpenAICompatProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} -void OpenAICompatProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - for (const QString &line : lines) { - if (line.trimmed().isEmpty() || line == "data: [DONE]") { - continue; - } - - QJsonObject chunk = parseEventLine(line); - if (chunk.isEmpty()) - continue; - - processStreamChunk(requestId, chunk); + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } -} - -void OpenAICompatProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("OpenAICompatProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - m_dataBuffers.remove(requestId); - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); - emit fullResponseReceived(requestId, buffers.responseContent); - } - } - - cleanupRequest(requestId); -} - -void OpenAICompatProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - LOG_MESSAGE(QString("Tool execution complete for OpenAICompat request %1").arg(requestId)); - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - OpenAIMessage *message = m_messages[requestId]; - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - OpenAIMessage *message = m_messages[requestId]; - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray messages = continuationRequest["messages"].toArray(); - - messages.append(message->toProviderFormat()); - - QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); - for (const auto &toolMsg : toolResultMessages) { - messages.append(toolMsg); - } - - continuationRequest["messages"] = messages; - - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void OpenAICompatProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) -{ - QJsonArray choices = chunk["choices"].toArray(); - if (choices.isEmpty()) { - return; - } - - QJsonObject choice = choices[0].toObject(); - QJsonObject delta = choice["delta"].toObject(); - QString finishReason = choice["finish_reason"].toString(); - - OpenAIMessage *message = m_messages.value(requestId); - if (!message) { - message = new OpenAIMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW OpenAIMessage for request %1").arg(requestId)); - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting continuation for request %1").arg(requestId)); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); - } - - if (delta.contains("content") && !delta["content"].isNull()) { - QString content = delta["content"].toString(); - message->handleContentDelta(content); - - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - - if (delta.contains("tool_calls")) { - QJsonArray toolCalls = delta["tool_calls"].toArray(); - for (const auto &toolCallValue : toolCalls) { - QJsonObject toolCall = toolCallValue.toObject(); - int index = toolCall["index"].toInt(); - - if (toolCall.contains("id")) { - QString id = toolCall["id"].toString(); - QJsonObject function = toolCall["function"].toObject(); - QString name = function["name"].toString(); - message->handleToolCallStart(index, id, name); - } - - if (toolCall.contains("function")) { - QJsonObject function = toolCall["function"].toObject(); - if (function.contains("arguments")) { - QString args = function["arguments"].toString(); - message->handleToolCallDelta(index, args); - } - } - } - } - - if (!finishReason.isEmpty() && finishReason != "null") { - for (int i = 0; i < 10; ++i) { - message->handleToolCallComplete(i); - } - - message->handleFinishReason(finishReason); - handleMessageComplete(requestId); - } -} - -void OpenAICompatProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - OpenAIMessage *message = m_messages[requestId]; - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("OpenAICompat message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("OpenAICompat message marked as complete for %1").arg(requestId)); - } -} - -void OpenAICompatProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up OpenAICompat request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_client->tools()->cleanupRequest(requestId); + m_awaitingContinuation.remove(requestId); } ::LLMCore::ToolsManager *OpenAICompatProvider::toolsManager() const diff --git a/providers/OpenAICompatProvider.hpp b/providers/OpenAICompatProvider.hpp index b779f00..539fe96 100644 --- a/providers/OpenAICompatProvider.hpp +++ b/providers/OpenAICompatProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,7 +19,8 @@ #pragma once -#include "OpenAIMessage.hpp" +#include + #include #include @@ -58,26 +59,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamChunk(const QString &requestId, const QJsonObject &chunk); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; ::LLMCore::OpenAIClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/OpenAIProvider.cpp b/providers/OpenAIProvider.cpp index 955817f..d3fa411 100644 --- a/providers/OpenAIProvider.cpp +++ b/providers/OpenAIProvider.cpp @@ -40,12 +40,6 @@ OpenAIProvider::OpenAIProvider(QObject *parent) , m_client(new ::LLMCore::OpenAIClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &OpenAIProvider::onToolExecutionComplete); } QString OpenAIProvider::name() const @@ -129,11 +123,6 @@ void OpenAIProvider::prepareRequest( } if (isToolsEnabled) { - PluginLLMCore::RunToolsFilter filter = PluginLLMCore::RunToolsFilter::ALL; - if (type == PluginLLMCore::RequestType::QuickRefactoring) { - filter = PluginLLMCore::RunToolsFilter::OnlyRead; - } - auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { request["tools"] = toolsDefinitions; @@ -142,36 +131,20 @@ void OpenAIProvider::prepareRequest( } } -QFuture> OpenAIProvider::getInstalledModels(const QString &url) +QFuture> OpenAIProvider::getInstalledModels(const QString &baseUrl) { - QNetworkRequest request(QString("%1/v1/models").arg(url)); - request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); - if (!apiKey().isEmpty()) { - request.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8()); - } - - return httpClient()->get(request).then([](const QByteArray &data) { - QList models; - QJsonObject jsonObject = QJsonDocument::fromJson(data).object(); - - if (jsonObject.contains("data")) { - QJsonArray modelArray = jsonObject["data"].toArray(); - for (const QJsonValue &value : modelArray) { - QJsonObject modelObject = value.toObject(); - if (modelObject.contains("id")) { - QString modelId = modelObject["id"].toString(); - if (!modelId.contains("dall-e") && !modelId.contains("whisper") - && !modelId.contains("tts") && !modelId.contains("davinci") - && !modelId.contains("babbage") && !modelId.contains("omni")) { - models.append(modelId); - } - } + m_client->setUrl(baseUrl); + m_client->setApiKey(apiKey()); + return m_client->listModels().then([](const QList &allModels) { + QList filtered; + for (const QString &modelId : allModels) { + if (!modelId.contains("dall-e") && !modelId.contains("whisper") + && !modelId.contains("tts") && !modelId.contains("davinci") + && !modelId.contains("babbage") && !modelId.contains("omni")) { + filtered.append(modelId); } } - return models; - }).onFailed([](const std::exception &e) { - LOG_MESSAGE(QString("Error fetching OpenAI models: %1").arg(e.what())); - return QList{}; + return filtered; }); } @@ -216,19 +189,69 @@ PluginLLMCore::ProviderID OpenAIProvider::providerID() const void OpenAIProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + emit partialResponseReceived(requestId, chunk); + }; - LOG_MESSAGE(QString("OpenAIProvider: Sending request %1 to %2").arg(requestId, url.toString())); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) { + emit continuationStarted(requestId); + } + if (thinking.isEmpty()) { + emit redactedThinkingBlockReceived(requestId, signature); + } else { + emit thinkingBlockReceived(requestId, thinking, signature); + } + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("OpenAIProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool OpenAIProvider::supportsTools() const @@ -244,216 +267,13 @@ bool OpenAIProvider::supportImage() const void OpenAIProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { LOG_MESSAGE(QString("OpenAIProvider: Cancelling request %1").arg(requestId)); - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} -void OpenAIProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - QStringList lines = buffers.rawStreamBuffer.processData(data); - - for (const QString &line : lines) { - if (line.trimmed().isEmpty() || line == "data: [DONE]") { - continue; - } - - QJsonObject chunk = parseEventLine(line); - if (chunk.isEmpty()) - continue; - - processStreamChunk(requestId, chunk); + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } -} - -void OpenAIProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("OpenAIProvider request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); - m_dataBuffers.remove(requestId); - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); - emit fullResponseReceived(requestId, buffers.responseContent); - } - } - - cleanupRequest(requestId); -} - -void OpenAIProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); - cleanupRequest(requestId); - return; - } - - LOG_MESSAGE(QString("Tool execution complete for OpenAI request %1").arg(requestId)); - - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { - OpenAIMessage *message = m_messages[requestId]; - auto toolContent = message->getCurrentToolUseContent(); - for (auto tool : toolContent) { - if (tool->id() == it.key()) { - auto toolStringName = m_client->tools()->displayName(tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - OpenAIMessage *message = m_messages[requestId]; - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray messages = continuationRequest["messages"].toArray(); - - messages.append(message->toProviderFormat()); - - QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); - for (const auto &toolMsg : toolResultMessages) { - messages.append(toolMsg); - } - - continuationRequest["messages"] = messages; - - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void OpenAIProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) -{ - QJsonArray choices = chunk["choices"].toArray(); - if (choices.isEmpty()) { - return; - } - - QJsonObject choice = choices[0].toObject(); - QJsonObject delta = choice["delta"].toObject(); - QString finishReason = choice["finish_reason"].toString(); - - OpenAIMessage *message = m_messages.value(requestId); - if (!message) { - message = new OpenAIMessage(this); - m_messages[requestId] = message; - LOG_MESSAGE(QString("Created NEW OpenAIAPIMessage for request %1").arg(requestId)); - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Starting continuation for request %1").arg(requestId)); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - emit continuationStarted(requestId); - LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); - } - - if (delta.contains("content") && !delta["content"].isNull()) { - QString content = delta["content"].toString(); - message->handleContentDelta(content); - - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += content; - emit partialResponseReceived(requestId, content); - } - - if (delta.contains("tool_calls")) { - QJsonArray toolCalls = delta["tool_calls"].toArray(); - for (const auto &toolCallValue : toolCalls) { - QJsonObject toolCall = toolCallValue.toObject(); - int index = toolCall["index"].toInt(); - - if (toolCall.contains("id")) { - QString id = toolCall["id"].toString(); - QJsonObject function = toolCall["function"].toObject(); - QString name = function["name"].toString(); - message->handleToolCallStart(index, id, name); - } - - if (toolCall.contains("function")) { - QJsonObject function = toolCall["function"].toObject(); - if (function.contains("arguments")) { - QString args = function["arguments"].toString(); - message->handleToolCallDelta(index, args); - } - } - } - } - - if (!finishReason.isEmpty() && finishReason != "null") { - for (int i = 0; i < 10; ++i) { - message->handleToolCallComplete(i); - } - - message->handleFinishReason(finishReason); - handleMessageComplete(requestId); - } -} - -void OpenAIProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) - return; - - OpenAIMessage *message = m_messages[requestId]; - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - LOG_MESSAGE(QString("OpenAI message requires tool execution for %1").arg(requestId)); - - auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); - return; - } - - for (auto toolContent : toolUseContent) { - auto toolStringName = m_client->tools()->displayName(toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - - } else { - LOG_MESSAGE(QString("OpenAI message marked as complete for %1").arg(requestId)); - } -} - -void OpenAIProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - LOG_MESSAGE(QString("Cleaning up OpenAI request %1").arg(requestId)); - - if (m_messages.contains(requestId)) { - OpenAIMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_client->tools()->cleanupRequest(requestId); + m_awaitingContinuation.remove(requestId); } ::LLMCore::ToolsManager *OpenAIProvider::toolsManager() const diff --git a/providers/OpenAIProvider.hpp b/providers/OpenAIProvider.hpp index 91c34d7..3dedee4 100644 --- a/providers/OpenAIProvider.hpp +++ b/providers/OpenAIProvider.hpp @@ -19,7 +19,8 @@ #pragma once -#include "OpenAIMessage.hpp" +#include + #include #include @@ -58,26 +59,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamChunk(const QString &requestId, const QJsonObject &chunk); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; ::LLMCore::OpenAIClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers diff --git a/providers/OpenAIResponsesProvider.cpp b/providers/OpenAIResponsesProvider.cpp index 03dfbb6..acb5782 100644 --- a/providers/OpenAIResponsesProvider.cpp +++ b/providers/OpenAIResponsesProvider.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,7 +19,6 @@ #include "OpenAIResponsesProvider.hpp" #include -#include "OpenAIResponses/ResponseObject.hpp" #include "tools/ToolsRegistration.hpp" #include "pluginllmcore/ValidationUtils.hpp" @@ -41,12 +40,6 @@ OpenAIResponsesProvider::OpenAIResponsesProvider(QObject *parent) , m_client(new ::LLMCore::OpenAIResponsesClient(url(), apiKey(), QString(), this)) { Tools::registerQodeAssistTools(m_client->tools()); - - connect( - m_client->tools(), - &::LLMCore::ToolsManager::toolExecutionComplete, - this, - &OpenAIResponsesProvider::onToolExecutionComplete); } QString OpenAIResponsesProvider::name() const @@ -101,7 +94,7 @@ void OpenAIResponsesProvider::prepareRequest( if (effortStr.isEmpty()) { effortStr = "medium"; } - + QJsonObject reasoning; reasoning["effort"] = effortStr; request["reasoning"] = reasoning; @@ -132,10 +125,6 @@ void OpenAIResponsesProvider::prepareRequest( } if (isToolsEnabled) { - const PluginLLMCore::RunToolsFilter filter = (type == PluginLLMCore::RequestType::QuickRefactoring) - ? PluginLLMCore::RunToolsFilter::OnlyRead - : PluginLLMCore::RunToolsFilter::ALL; - const auto toolsDefinitions = m_client->tools()->getToolsDefinitions(); if (!toolsDefinitions.isEmpty()) { @@ -160,43 +149,22 @@ void OpenAIResponsesProvider::prepareRequest( request["stream"] = true; } -QFuture> OpenAIResponsesProvider::getInstalledModels(const QString &url) +QFuture> OpenAIResponsesProvider::getInstalledModels(const QString &baseUrl) { - QNetworkRequest request(QString("%1/v1/models").arg(url)); - request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); - if (!apiKey().isEmpty()) { - request.setRawHeader("Authorization", QString("Bearer %1").arg(apiKey()).toUtf8()); - } - - return httpClient()->get(request).then([](const QByteArray &data) { - QList models; - const QJsonObject jsonObject = QJsonDocument::fromJson(data).object(); - - if (jsonObject.contains("data")) { - const QJsonArray modelArray = jsonObject["data"].toArray(); - models.reserve(modelArray.size()); - - static const QStringList modelPrefixes = {"gpt-5", "o1", "o2", "o3", "o4"}; - - for (const QJsonValue &value : modelArray) { - const QJsonObject modelObject = value.toObject(); - if (!modelObject.contains("id")) { - continue; - } - - const QString modelId = modelObject["id"].toString(); - for (const QString &prefix : modelPrefixes) { - if (modelId.contains(prefix)) { - models.append(modelId); - break; - } + m_client->setUrl(baseUrl); + m_client->setApiKey(apiKey()); + return m_client->listModels().then([](const QList &models) { + QList filtered; + static const QStringList modelPrefixes = {"gpt-5", "o1", "o2", "o3", "o4"}; + for (const QString &modelId : models) { + for (const QString &prefix : modelPrefixes) { + if (modelId.contains(prefix)) { + filtered.append(modelId); + break; } } } - return models; - }).onFailed([](const std::exception &e) { - LOG_MESSAGE(QString("Error fetching OpenAI models: %1").arg(e.what())); - return QList{}; + return filtered; }); } @@ -262,17 +230,66 @@ PluginLLMCore::ProviderID OpenAIResponsesProvider::providerID() const void OpenAIResponsesProvider::sendRequest( const PluginLLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - if (!m_messages.contains(requestId)) { - m_dataBuffers[requestId].clear(); - } + QUrl baseUrl(url); + baseUrl.setPath(""); + m_client->setUrl(baseUrl.toString()); + m_client->setApiKey(apiKey()); - m_requestUrls[requestId] = url; - m_originalRequests[requestId] = payload; + ::LLMCore::RequestCallbacks callbacks; - QNetworkRequest networkRequest(url); - prepareNetworkRequest(networkRequest); + callbacks.onChunk = [this, requestId](const ::LLMCore::RequestID &, const QString &chunk) { + if (m_awaitingContinuation.remove(requestId)) + emit continuationStarted(requestId); + emit partialResponseReceived(requestId, chunk); + }; - httpClient()->postStreaming(requestId, networkRequest, payload); + callbacks.onCompleted + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &fullText) { + emit fullResponseReceived(requestId, fullText); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onFailed + = [this, requestId](const ::LLMCore::RequestID &clientId, const QString &error) { + emit requestFailed(requestId, error); + m_providerToClientIds.remove(requestId); + m_clientToProviderIds.remove(clientId); + m_awaitingContinuation.remove(requestId); + }; + + callbacks.onThinkingBlock = [this, requestId](const ::LLMCore::RequestID &, + const QString &thinking, + const QString &signature) { + if (m_awaitingContinuation.remove(requestId)) + emit continuationStarted(requestId); + if (thinking.isEmpty()) + emit redactedThinkingBlockReceived(requestId, signature); + else + emit thinkingBlockReceived(requestId, thinking, signature); + }; + + callbacks.onToolStarted = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName) { + emit toolExecutionStarted(requestId, toolId, toolName); + m_awaitingContinuation.insert(requestId); + }; + + callbacks.onToolResult = [this, requestId](const ::LLMCore::RequestID &, + const QString &toolId, + const QString &toolName, + const QString &result) { + emit toolExecutionCompleted(requestId, toolId, toolName, result); + }; + + auto clientId = m_client->sendMessage(payload, callbacks); + m_providerToClientIds[requestId] = clientId; + m_clientToProviderIds[clientId] = requestId; + + LOG_MESSAGE(QString("OpenAIResponsesProvider: Sending request %1 (client: %2) to %3") + .arg(requestId, clientId, url.toString())); } bool OpenAIResponsesProvider::supportsTools() const @@ -292,364 +309,14 @@ bool OpenAIResponsesProvider::supportThinking() const void OpenAIResponsesProvider::cancelRequest(const PluginLLMCore::RequestID &requestId) { - PluginLLMCore::Provider::cancelRequest(requestId); - cleanupRequest(requestId); -} + LOG_MESSAGE(QString("OpenAIResponsesProvider: Cancelling request %1").arg(requestId)); -void OpenAIResponsesProvider::onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) -{ - PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - const QStringList lines = buffers.rawStreamBuffer.processData(data); - - QString currentEventType; - - for (const QString &line : lines) { - const QString trimmedLine = line.trimmed(); - if (trimmedLine.isEmpty()) { - continue; - } - - if (line == "data: [DONE]") { - continue; - } - - if (line.startsWith("event: ")) { - currentEventType = line.mid(7).trimmed(); - continue; - } - - QString dataLine = line; - if (line.startsWith("data: ")) { - dataLine = line.mid(6); - } - - const QJsonDocument doc = QJsonDocument::fromJson(dataLine.toUtf8()); - if (doc.isObject()) { - const QJsonObject obj = doc.object(); - processStreamEvent(requestId, currentEventType, obj); - } + if (m_providerToClientIds.contains(requestId)) { + auto clientId = m_providerToClientIds.take(requestId); + m_clientToProviderIds.remove(clientId); + m_client->cancelRequest(clientId); } -} - -void OpenAIResponsesProvider::onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, std::optional error) -{ - if (error) { - LOG_MESSAGE(QString("OpenAIResponses request %1 failed: %2").arg(requestId, *error)); - emit requestFailed(requestId, *error); - cleanupRequest(requestId); - return; - } - - if (m_messages.contains(requestId)) { - OpenAIResponsesMessage *message = m_messages[requestId]; - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - return; - } - } - - if (m_dataBuffers.contains(requestId)) { - const PluginLLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - emit fullResponseReceived(requestId, buffers.responseContent); - } else { - LOG_MESSAGE(QString("WARNING: OpenAIResponses - Response content is empty for %1, " - "emitting empty response") - .arg(requestId)); - emit fullResponseReceived(requestId, ""); - } - } else { - LOG_MESSAGE( - QString("WARNING: OpenAIResponses - No data buffer found for %1").arg(requestId)); - } - - cleanupRequest(requestId); -} - -void OpenAIResponsesProvider::processStreamEvent( - const QString &requestId, const QString &eventType, const QJsonObject &data) -{ - OpenAIResponsesMessage *message = m_messages.value(requestId); - if (!message) { - message = new OpenAIResponsesMessage(this); - m_messages[requestId] = message; - - if (m_dataBuffers.contains(requestId)) { - emit continuationStarted(requestId); - } - } else if ( - m_dataBuffers.contains(requestId) - && message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - message->startNewContinuation(); - emit continuationStarted(requestId); - } - - if (eventType == "response.content_part.added") { - } else if (eventType == "response.output_text.delta") { - const QString delta = data["delta"].toString(); - if (!delta.isEmpty()) { - m_dataBuffers[requestId].responseContent += delta; - emit partialResponseReceived(requestId, delta); - } - } else if (eventType == "response.output_text.done") { - const QString fullText = data["text"].toString(); - if (!fullText.isEmpty()) { - m_dataBuffers[requestId].responseContent = fullText; - } - } else if (eventType == "response.content_part.done") { - } else if (eventType == "response.output_item.added") { - using namespace QodeAssist::OpenAIResponses; - const QJsonObject item = data["item"].toObject(); - OutputItem outputItem = OutputItem::fromJson(item); - - if (const auto *functionCall = outputItem.asFunctionCall()) { - if (!functionCall->callId.isEmpty() && !functionCall->name.isEmpty()) { - if (!m_itemIdToCallId.contains(requestId)) { - m_itemIdToCallId[requestId] = QHash(); - } - m_itemIdToCallId[requestId][functionCall->id] = functionCall->callId; - message->handleToolCallStart(functionCall->callId, functionCall->name); - } - } else if (const auto *reasoning = outputItem.asReasoning()) { - if (!reasoning->id.isEmpty()) { - message->handleReasoningStart(reasoning->id); - } - } - } else if (eventType == "response.reasoning_content.delta") { - const QString itemId = data["item_id"].toString(); - const QString delta = data["delta"].toString(); - if (!itemId.isEmpty() && !delta.isEmpty()) { - message->handleReasoningDelta(itemId, delta); - } - } else if (eventType == "response.reasoning_content.done") { - const QString itemId = data["item_id"].toString(); - if (!itemId.isEmpty()) { - message->handleReasoningComplete(itemId); - emitPendingThinkingBlocks(requestId); - } - } else if (eventType == "response.function_call_arguments.delta") { - const QString itemId = data["item_id"].toString(); - const QString delta = data["delta"].toString(); - if (!itemId.isEmpty() && !delta.isEmpty()) { - const QString callId = m_itemIdToCallId.value(requestId).value(itemId); - if (!callId.isEmpty()) { - message->handleToolCallDelta(callId, delta); - } else { - LOG_MESSAGE(QString("ERROR: No call_id mapping found for item_id: %1").arg(itemId)); - } - } - } else if ( - eventType == "response.function_call_arguments.done" - || eventType == "response.output_item.done") { - const QString itemId = data["item_id"].toString(); - const QJsonObject item = data["item"].toObject(); - - if (!item.isEmpty() && item["type"].toString() == "reasoning") { - using namespace QodeAssist::OpenAIResponses; - - const QString finalItemId = itemId.isEmpty() ? item["id"].toString() : itemId; - - ReasoningOutput reasoningOutput = ReasoningOutput::fromJson(item); - QString reasoningText; - - if (!reasoningOutput.summaryText.isEmpty()) { - reasoningText = reasoningOutput.summaryText; - } else if (!reasoningOutput.contentTexts.isEmpty()) { - reasoningText = reasoningOutput.contentTexts.join("\n"); - } - - if (reasoningText.isEmpty()) { - reasoningText = QString( - "[Reasoning process completed, but detailed thinking is not available in " - "streaming mode. The model has processed your request with extended reasoning.]"); - } - - if (!finalItemId.isEmpty()) { - message->handleReasoningDelta(finalItemId, reasoningText); - message->handleReasoningComplete(finalItemId); - emitPendingThinkingBlocks(requestId); - } - } else if (item.isEmpty() && !itemId.isEmpty()) { - const QString callId = m_itemIdToCallId.value(requestId).value(itemId); - if (!callId.isEmpty()) { - message->handleToolCallComplete(callId); - } else { - LOG_MESSAGE( - QString("ERROR: OpenAIResponses - No call_id mapping found for item_id: %1") - .arg(itemId)); - } - } else if (!item.isEmpty() && item["type"].toString() == "function_call") { - const QString callId = item["call_id"].toString(); - if (!callId.isEmpty()) { - message->handleToolCallComplete(callId); - } else { - LOG_MESSAGE( - QString("ERROR: OpenAIResponses - Function call done but call_id is empty")); - } - } - } else if (eventType == "response.created") { - } else if (eventType == "response.in_progress") { - } else if (eventType == "response.completed") { - using namespace QodeAssist::OpenAIResponses; - const QJsonObject responseObj = data["response"].toObject(); - Response response = Response::fromJson(responseObj); - - const QString statusStr = responseObj["status"].toString(); - - if (m_dataBuffers[requestId].responseContent.isEmpty()) { - const QString aggregatedText = response.getAggregatedText(); - if (!aggregatedText.isEmpty()) { - m_dataBuffers[requestId].responseContent = aggregatedText; - } - } - - message->handleStatus(statusStr); - handleMessageComplete(requestId); - } else if (eventType == "response.incomplete") { - using namespace QodeAssist::OpenAIResponses; - const QJsonObject responseObj = data["response"].toObject(); - - if (!responseObj.isEmpty()) { - Response response = Response::fromJson(responseObj); - const QString statusStr = responseObj["status"].toString(); - - if (m_dataBuffers[requestId].responseContent.isEmpty()) { - const QString aggregatedText = response.getAggregatedText(); - if (!aggregatedText.isEmpty()) { - m_dataBuffers[requestId].responseContent = aggregatedText; - } - } - - message->handleStatus(statusStr); - } else { - message->handleStatus("incomplete"); - } - - handleMessageComplete(requestId); - } else if (!eventType.isEmpty()) { - LOG_MESSAGE(QString("WARNING: OpenAIResponses - Unhandled event type '%1' for request %2\nData: %3") - .arg(eventType) - .arg(requestId) - .arg(QString::fromUtf8(QJsonDocument(data).toJson(QJsonDocument::Compact)))); - } -} - -void OpenAIResponsesProvider::emitPendingThinkingBlocks(const QString &requestId) -{ - if (!m_messages.contains(requestId)) { - return; - } - - OpenAIResponsesMessage *message = m_messages[requestId]; - const auto thinkingBlocks = message->getCurrentThinkingContent(); - - if (thinkingBlocks.isEmpty()) { - return; - } - - const int alreadyEmitted = m_emittedThinkingBlocksCount.value(requestId, 0); - const int totalBlocks = thinkingBlocks.size(); - - for (int i = alreadyEmitted; i < totalBlocks; ++i) { - const auto *thinkingContent = thinkingBlocks[i]; - - if (thinkingContent->thinking().trimmed().isEmpty()) { - continue; - } - - emit thinkingBlockReceived( - requestId, thinkingContent->thinking(), thinkingContent->signature()); - } - - m_emittedThinkingBlocksCount[requestId] = totalBlocks; -} - -void OpenAIResponsesProvider::handleMessageComplete(const QString &requestId) -{ - if (!m_messages.contains(requestId)) { - return; - } - - OpenAIResponsesMessage *message = m_messages[requestId]; - - emitPendingThinkingBlocks(requestId); - - if (message->state() == PluginLLMCore::MessageState::RequiresToolExecution) { - const auto toolUseContent = message->getCurrentToolUseContent(); - - if (toolUseContent.isEmpty()) { - return; - } - - for (const auto *toolContent : toolUseContent) { - const auto toolStringName = m_client->tools()->displayName( - toolContent->name()); - emit toolExecutionStarted(requestId, toolContent->id(), toolStringName); - m_client->tools()->executeToolCall( - requestId, toolContent->id(), toolContent->name(), toolContent->input()); - } - } -} - -void OpenAIResponsesProvider::onToolExecutionComplete( - const QString &requestId, const QHash &toolResults) -{ - if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { - LOG_MESSAGE(QString("ERROR: OpenAIResponses - Missing data for continuation request %1") - .arg(requestId)); - cleanupRequest(requestId); - return; - } - - OpenAIResponsesMessage *message = m_messages[requestId]; - const auto toolContent = message->getCurrentToolUseContent(); - - for (auto it = toolResults.constBegin(); it != toolResults.constEnd(); ++it) { - for (const auto *tool : toolContent) { - if (tool->id() == it.key()) { - const auto toolStringName = m_client->tools()->displayName( - tool->name()); - emit toolExecutionCompleted( - requestId, tool->id(), toolStringName, toolResults[tool->id()]); - break; - } - } - } - - QJsonObject continuationRequest = m_originalRequests[requestId]; - QJsonArray input = continuationRequest["input"].toArray(); - - const QList assistantItems = message->toItemsFormat(); - for (const QJsonObject &item : assistantItems) { - input.append(item); - } - - const QJsonArray toolResultItems = message->createToolResultItems(toolResults); - for (const QJsonValue &item : toolResultItems) { - input.append(item); - } - - continuationRequest["input"] = input; - - m_dataBuffers[requestId].responseContent.clear(); - - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); -} - -void OpenAIResponsesProvider::cleanupRequest(const PluginLLMCore::RequestID &requestId) -{ - if (m_messages.contains(requestId)) { - OpenAIResponsesMessage *message = m_messages.take(requestId); - message->deleteLater(); - } - - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); - m_originalRequests.remove(requestId); - m_itemIdToCallId.remove(requestId); - m_emittedThinkingBlocksCount.remove(requestId); - m_client->tools()->cleanupRequest(requestId); + m_awaitingContinuation.remove(requestId); } ::LLMCore::ToolsManager *OpenAIResponsesProvider::toolsManager() const diff --git a/providers/OpenAIResponsesProvider.hpp b/providers/OpenAIResponsesProvider.hpp index 1f63e77..1e27a94 100644 --- a/providers/OpenAIResponsesProvider.hpp +++ b/providers/OpenAIResponsesProvider.hpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (C) 2024-2025 Petr Mironychev * * This file is part of QodeAssist. @@ -19,7 +19,8 @@ #pragma once -#include "OpenAIResponsesMessage.hpp" +#include + #include #include @@ -59,30 +60,11 @@ public: ::LLMCore::ToolsManager *toolsManager() const override; -public slots: - void onDataReceived( - const QodeAssist::PluginLLMCore::RequestID &requestId, const QByteArray &data) override; - void onRequestFinished( - const QodeAssist::PluginLLMCore::RequestID &requestId, - std::optional error) override; - -private slots: - void onToolExecutionComplete( - const QString &requestId, const QHash &toolResults); - private: - void processStreamEvent(const QString &requestId, const QString &eventType, const QJsonObject &data); - void emitPendingThinkingBlocks(const QString &requestId); - void handleMessageComplete(const QString &requestId); - void cleanupRequest(const PluginLLMCore::RequestID &requestId); - - QHash m_messages; - QHash m_requestUrls; - QHash m_originalRequests; - QHash> m_itemIdToCallId; - QHash m_emittedThinkingBlocksCount; ::LLMCore::OpenAIResponsesClient *m_client; + QHash m_providerToClientIds; + QHash<::LLMCore::RequestID, PluginLLMCore::RequestID> m_clientToProviderIds; + QSet m_awaitingContinuation; }; } // namespace QodeAssist::Providers -