From 30885c0373d7844a834704c792d7c339b2f5268e Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Thu, 13 Nov 2025 23:52:38 +0100 Subject: [PATCH] feat: Add google provider thinking mode (#255) fix: add signature --- providers/GoogleAIProvider.cpp | 161 ++++++++++++++++++++++++++++++--- providers/GoogleAIProvider.hpp | 4 + providers/GoogleMessage.cpp | 89 ++++++++++++++++-- providers/GoogleMessage.hpp | 6 ++ 4 files changed, 236 insertions(+), 24 deletions(-) diff --git a/providers/GoogleAIProvider.cpp b/providers/GoogleAIProvider.cpp index e062d84..5b4268d 100644 --- a/providers/GoogleAIProvider.cpp +++ b/providers/GoogleAIProvider.cpp @@ -100,7 +100,49 @@ void GoogleAIProvider::prepareRequest( if (type == LLMCore::RequestType::CodeCompletion) { applyModelParams(Settings::codeCompletionSettings()); } else { - applyModelParams(Settings::chatAssistantSettings()); + const auto &chatSettings = Settings::chatAssistantSettings(); + + if (chatSettings.enableThinkingMode()) { + QJsonObject generationConfig; + generationConfig["maxOutputTokens"] = chatSettings.thinkingMaxTokens(); + + if (chatSettings.useTopP()) + generationConfig["topP"] = chatSettings.topP(); + if (chatSettings.useTopK()) + generationConfig["topK"] = chatSettings.topK(); + + // Set temperature to 1.0 for thinking mode + generationConfig["temperature"] = 1.0; + + // Add thinkingConfig + QJsonObject thinkingConfig; + int budgetTokens = chatSettings.thinkingBudgetTokens(); + + // Dynamic thinking: -1 (let model decide) + // Disabled: 0 (no thinking) + // Custom budget: positive integer + if (budgetTokens == -1) { + // Dynamic thinking - omit budget to let model decide + thinkingConfig["includeThoughts"] = true; + } else if (budgetTokens == 0) { + // Disabled thinking + thinkingConfig["thinkingBudget"] = 0; + thinkingConfig["includeThoughts"] = false; + } else { + // Custom budget + thinkingConfig["thinkingBudget"] = budgetTokens; + thinkingConfig["includeThoughts"] = true; + } + + generationConfig["thinkingConfig"] = thinkingConfig; + request["generationConfig"] = generationConfig; + + LOG_MESSAGE(QString("Google AI thinking mode enabled: budget=%1 tokens, maxTokens=%2") + .arg(budgetTokens) + .arg(chatSettings.thinkingMaxTokens())); + } else { + applyModelParams(chatSettings); + } } if (isToolsEnabled) { @@ -164,7 +206,13 @@ QList GoogleAIProvider::validateRequest( {"contents", QJsonArray{}}, {"system_instruction", QJsonArray{}}, {"generationConfig", - QJsonObject{{"temperature", {}}, {"maxOutputTokens", {}}, {"topP", {}}, {"topK", {}}}}, + QJsonObject{ + {"temperature", {}}, + {"maxOutputTokens", {}}, + {"topP", {}}, + {"topK", {}}, + {"thinkingConfig", + QJsonObject{{"thinkingBudget", {}}, {"includeThoughts", {}}}}}}, {"safetySettings", QJsonArray{}}, {"tools", QJsonArray{}}}; @@ -219,6 +267,11 @@ bool GoogleAIProvider::supportsTools() const return true; } +bool GoogleAIProvider::supportThinking() const +{ + return true; +} + void GoogleAIProvider::cancelRequest(const LLMCore::RequestID &requestId) { LOG_MESSAGE(QString("GoogleAIProvider: Cancelling request %1").arg(requestId)); @@ -277,8 +330,18 @@ void GoogleAIProvider::onRequestFinished( return; } + if (m_failedRequests.contains(requestId)) { + cleanupRequest(requestId); + return; + } + + emitPendingThinkingBlocks(requestId); + if (m_messages.contains(requestId)) { GoogleMessage *message = m_messages[requestId]; + + handleMessageComplete(requestId); + if (message->state() == LLMCore::MessageState::RequiresToolExecution) { LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); m_dataBuffers.remove(requestId); @@ -289,9 +352,12 @@ void GoogleAIProvider::onRequestFinished( if (m_dataBuffers.contains(requestId)) { const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; if (!buffers.responseContent.isEmpty()) { - LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); emit fullResponseReceived(requestId, buffers.responseContent); + } else { + emit fullResponseReceived(requestId, QString()); } + } else { + emit fullResponseReceived(requestId, QString()); } cleanupRequest(requestId); @@ -306,8 +372,6 @@ void GoogleAIProvider::onToolExecutionComplete( return; } - LOG_MESSAGE(QString("Tool execution complete for Google AI request %1").arg(requestId)); - for (auto it = toolResults.begin(); it != toolResults.end(); ++it) { GoogleMessage *message = m_messages[requestId]; auto toolContent = message->getCurrentToolUseContent(); @@ -334,10 +398,6 @@ void GoogleAIProvider::onToolExecutionComplete( continuationRequest["contents"] = contents; - LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") - .arg(requestId) - .arg(toolResults.size())); - sendRequest(requestId, m_requestUrls[requestId], continuationRequest); } @@ -361,6 +421,7 @@ void GoogleAIProvider::processStreamChunk(const QString &requestId, const QJsonO m_dataBuffers.contains(requestId) && message->state() == LLMCore::MessageState::RequiresToolExecution) { message->startNewContinuation(); + m_emittedThinkingBlocksCount[requestId] = 0; LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId)); } @@ -377,12 +438,34 @@ void GoogleAIProvider::processStreamChunk(const QString &requestId, const QJsonO if (partObj.contains("text")) { QString text = partObj["text"].toString(); - message->handleContentDelta(text); + bool isThought = partObj.value("thought").toBool(false); + + if (isThought) { + message->handleThoughtDelta(text); + + if (partObj.contains("signature")) { + QString signature = partObj["signature"].toString(); + message->handleThoughtSignature(signature); + } + } else { + emitPendingThinkingBlocks(requestId); + + message->handleContentDelta(text); - LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - buffers.responseContent += text; - emit partialResponseReceived(requestId, text); - } else if (partObj.contains("functionCall")) { + LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; + buffers.responseContent += text; + emit partialResponseReceived(requestId, text); + } + } + + if (partObj.contains("thoughtSignature")) { + QString signature = partObj["thoughtSignature"].toString(); + message->handleThoughtSignature(signature); + } + + if (partObj.contains("functionCall")) { + emitPendingThinkingBlocks(requestId); + QJsonObject functionCall = partObj["functionCall"].toObject(); QString name = functionCall["name"].toString(); QJsonObject args = functionCall["args"].toObject(); @@ -399,9 +482,55 @@ void GoogleAIProvider::processStreamChunk(const QString &requestId, const QJsonO if (candidateObj.contains("finishReason")) { QString finishReason = candidateObj["finishReason"].toString(); message->handleFinishReason(finishReason); - handleMessageComplete(requestId); + + if (message->isErrorFinishReason()) { + QString errorMessage = message->getErrorMessage(); + LOG_MESSAGE(QString("Google AI error: %1").arg(errorMessage)); + m_failedRequests.insert(requestId); + emit requestFailed(requestId, errorMessage); + return; + } } } + + if (chunk.contains("usageMetadata")) { + QJsonObject usageMetadata = chunk["usageMetadata"].toObject(); + int thoughtsTokenCount = usageMetadata.value("thoughtsTokenCount").toInt(0); + int candidatesTokenCount = usageMetadata.value("candidatesTokenCount").toInt(0); + int totalTokenCount = usageMetadata.value("totalTokenCount").toInt(0); + + if (totalTokenCount > 0) { + LOG_MESSAGE(QString("Google AI tokens: %1 (thoughts: %2, output: %3)") + .arg(totalTokenCount) + .arg(thoughtsTokenCount) + .arg(candidatesTokenCount)); + } + } +} + +void GoogleAIProvider::emitPendingThinkingBlocks(const QString &requestId) +{ + if (!m_messages.contains(requestId)) + return; + + GoogleMessage *message = m_messages[requestId]; + auto thinkingBlocks = message->getCurrentThinkingContent(); + + if (thinkingBlocks.isEmpty()) + return; + + int alreadyEmitted = m_emittedThinkingBlocksCount.value(requestId, 0); + int totalBlocks = thinkingBlocks.size(); + + for (int i = alreadyEmitted; i < totalBlocks; ++i) { + auto thinkingContent = thinkingBlocks[i]; + emit thinkingBlockReceived( + requestId, + thinkingContent->thinking(), + thinkingContent->signature()); + } + + m_emittedThinkingBlocksCount[requestId] = totalBlocks; } void GoogleAIProvider::handleMessageComplete(const QString &requestId) @@ -445,6 +574,8 @@ void GoogleAIProvider::cleanupRequest(const LLMCore::RequestID &requestId) m_dataBuffers.remove(requestId); m_requestUrls.remove(requestId); m_originalRequests.remove(requestId); + m_emittedThinkingBlocksCount.remove(requestId); + m_failedRequests.remove(requestId); m_toolsManager->cleanupRequest(requestId); } diff --git a/providers/GoogleAIProvider.hpp b/providers/GoogleAIProvider.hpp index 5893342..851126e 100644 --- a/providers/GoogleAIProvider.hpp +++ b/providers/GoogleAIProvider.hpp @@ -52,6 +52,7 @@ public: const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) override; bool supportsTools() const override; + bool supportThinking() const override; void cancelRequest(const LLMCore::RequestID &requestId) override; public slots: @@ -69,11 +70,14 @@ private slots: private: void processStreamChunk(const QString &requestId, const QJsonObject &chunk); void handleMessageComplete(const QString &requestId); + void emitPendingThinkingBlocks(const QString &requestId); void cleanupRequest(const LLMCore::RequestID &requestId); QHash m_messages; QHash m_requestUrls; QHash m_originalRequests; + QHash m_emittedThinkingBlocksCount; + QSet m_failedRequests; Tools::ToolsManager *m_toolsManager; }; diff --git a/providers/GoogleMessage.cpp b/providers/GoogleMessage.cpp index 1d5117d..c2bfeeb 100644 --- a/providers/GoogleMessage.cpp +++ b/providers/GoogleMessage.cpp @@ -43,12 +43,38 @@ void GoogleMessage::handleContentDelta(const QString &text) } } +void GoogleMessage::handleThoughtDelta(const QString &text) +{ + if (m_currentBlocks.isEmpty() || !qobject_cast(m_currentBlocks.last())) { + auto thinkingContent = new LLMCore::ThinkingContent(); + thinkingContent->setParent(this); + m_currentBlocks.append(thinkingContent); + } + + if (auto thinkingContent = qobject_cast(m_currentBlocks.last())) { + thinkingContent->appendThinking(text); + } +} + +void GoogleMessage::handleThoughtSignature(const QString &signature) +{ + for (int i = m_currentBlocks.size() - 1; i >= 0; --i) { + if (auto thinkingContent = qobject_cast(m_currentBlocks[i])) { + thinkingContent->setSignature(signature); + return; + } + } + + auto thinkingContent = new LLMCore::ThinkingContent(); + thinkingContent->setParent(this); + thinkingContent->setSignature(signature); + m_currentBlocks.append(thinkingContent); +} + void GoogleMessage::handleFunctionCallStart(const QString &name) { m_currentFunctionName = name; m_pendingFunctionArgs.clear(); - - LOG_MESSAGE(QString("Google: Starting function call: %1").arg(name)); } void GoogleMessage::handleFunctionCallArgsDelta(const QString &argsJson) @@ -75,10 +101,6 @@ void GoogleMessage::handleFunctionCallComplete() toolContent->setParent(this); m_currentBlocks.append(toolContent); - LOG_MESSAGE(QString("Google: Completed function call: name=%1, args=%2") - .arg(m_currentFunctionName) - .arg(QString::fromUtf8(QJsonDocument(args).toJson(QJsonDocument::Compact)))); - m_currentFunctionName.clear(); m_pendingFunctionArgs.clear(); } @@ -87,9 +109,6 @@ void GoogleMessage::handleFinishReason(const QString &reason) { m_finishReason = reason; updateStateFromFinishReason(); - - LOG_MESSAGE( - QString("Google: Finish reason: %1, state: %2").arg(reason).arg(static_cast(m_state))); } QJsonObject GoogleMessage::toProviderFormat() const @@ -110,6 +129,19 @@ QJsonObject GoogleMessage::toProviderFormat() const functionCall["name"] = tool->name(); functionCall["args"] = tool->input(); parts.append(QJsonObject{{"functionCall", functionCall}}); + } else if (auto thinking = qobject_cast(block)) { + // Include thinking blocks with their text + QJsonObject thinkingPart; + thinkingPart["text"] = thinking->thinking(); + thinkingPart["thought"] = true; + parts.append(thinkingPart); + + // If there's a signature, add it as a separate part + if (!thinking->signature().isEmpty()) { + QJsonObject signaturePart; + signaturePart["thoughtSignature"] = thinking->signature(); + parts.append(signaturePart); + } } } @@ -148,6 +180,17 @@ QList GoogleMessage::getCurrentToolUseContent() const return toolBlocks; } +QList GoogleMessage::getCurrentThinkingContent() const +{ + QList thinkingBlocks; + for (auto block : m_currentBlocks) { + if (auto thinkingContent = qobject_cast(block)) { + thinkingBlocks.append(thinkingContent); + } + } + return thinkingBlocks; +} + void GoogleMessage::startNewContinuation() { LOG_MESSAGE(QString("GoogleMessage: Starting new continuation")); @@ -159,6 +202,34 @@ void GoogleMessage::startNewContinuation() m_state = LLMCore::MessageState::Building; } +bool GoogleMessage::isErrorFinishReason() const +{ + return m_finishReason == "SAFETY" + || m_finishReason == "RECITATION" + || m_finishReason == "MALFORMED_FUNCTION_CALL" + || m_finishReason == "PROHIBITED_CONTENT" + || m_finishReason == "SPII" + || m_finishReason == "OTHER"; +} + +QString GoogleMessage::getErrorMessage() const +{ + if (m_finishReason == "SAFETY") { + return "Response blocked by safety filters"; + } else if (m_finishReason == "RECITATION") { + return "Response blocked due to recitation of copyrighted content"; + } else if (m_finishReason == "MALFORMED_FUNCTION_CALL") { + return "Model attempted to call a function with malformed arguments. Please try rephrasing your request or disabling tools."; + } else if (m_finishReason == "PROHIBITED_CONTENT") { + return "Response blocked due to prohibited content"; + } else if (m_finishReason == "SPII") { + return "Response blocked due to sensitive personally identifiable information"; + } else if (m_finishReason == "OTHER") { + return "Request failed due to an unknown reason"; + } + return QString(); +} + void GoogleMessage::updateStateFromFinishReason() { if (m_finishReason == "STOP" || m_finishReason == "MAX_TOKENS") { diff --git a/providers/GoogleMessage.hpp b/providers/GoogleMessage.hpp index 263447b..036e8d1 100644 --- a/providers/GoogleMessage.hpp +++ b/providers/GoogleMessage.hpp @@ -35,6 +35,8 @@ public: explicit GoogleMessage(QObject *parent = nullptr); void handleContentDelta(const QString &text); + void handleThoughtDelta(const QString &text); + void handleThoughtSignature(const QString &signature); void handleFunctionCallStart(const QString &name); void handleFunctionCallArgsDelta(const QString &argsJson); void handleFunctionCallComplete(); @@ -44,9 +46,13 @@ public: QJsonArray createToolResultParts(const QHash &toolResults) const; QList getCurrentToolUseContent() const; + QList getCurrentThinkingContent() const; QList currentBlocks() const { return m_currentBlocks; } LLMCore::MessageState state() const { return m_state; } + QString finishReason() const { return m_finishReason; } + bool isErrorFinishReason() const; + QString getErrorMessage() const; void startNewContinuation(); private: