feat: Add google provider thinking mode (#255)

fix: add signature
This commit is contained in:
Petr Mironychev
2025-11-13 23:52:38 +01:00
committed by GitHub
parent 5e580b8792
commit 30885c0373
4 changed files with 236 additions and 24 deletions

View File

@ -100,7 +100,49 @@ void GoogleAIProvider::prepareRequest(
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else {
applyModelParams(Settings::chatAssistantSettings());
const auto &chatSettings = Settings::chatAssistantSettings();
if (chatSettings.enableThinkingMode()) {
QJsonObject generationConfig;
generationConfig["maxOutputTokens"] = chatSettings.thinkingMaxTokens();
if (chatSettings.useTopP())
generationConfig["topP"] = chatSettings.topP();
if (chatSettings.useTopK())
generationConfig["topK"] = chatSettings.topK();
// Set temperature to 1.0 for thinking mode
generationConfig["temperature"] = 1.0;
// Add thinkingConfig
QJsonObject thinkingConfig;
int budgetTokens = chatSettings.thinkingBudgetTokens();
// Dynamic thinking: -1 (let model decide)
// Disabled: 0 (no thinking)
// Custom budget: positive integer
if (budgetTokens == -1) {
// Dynamic thinking - omit budget to let model decide
thinkingConfig["includeThoughts"] = true;
} else if (budgetTokens == 0) {
// Disabled thinking
thinkingConfig["thinkingBudget"] = 0;
thinkingConfig["includeThoughts"] = false;
} else {
// Custom budget
thinkingConfig["thinkingBudget"] = budgetTokens;
thinkingConfig["includeThoughts"] = true;
}
generationConfig["thinkingConfig"] = thinkingConfig;
request["generationConfig"] = generationConfig;
LOG_MESSAGE(QString("Google AI thinking mode enabled: budget=%1 tokens, maxTokens=%2")
.arg(budgetTokens)
.arg(chatSettings.thinkingMaxTokens()));
} else {
applyModelParams(chatSettings);
}
}
if (isToolsEnabled) {
@ -164,7 +206,13 @@ QList<QString> GoogleAIProvider::validateRequest(
{"contents", QJsonArray{}},
{"system_instruction", QJsonArray{}},
{"generationConfig",
QJsonObject{{"temperature", {}}, {"maxOutputTokens", {}}, {"topP", {}}, {"topK", {}}}},
QJsonObject{
{"temperature", {}},
{"maxOutputTokens", {}},
{"topP", {}},
{"topK", {}},
{"thinkingConfig",
QJsonObject{{"thinkingBudget", {}}, {"includeThoughts", {}}}}}},
{"safetySettings", QJsonArray{}},
{"tools", QJsonArray{}}};
@ -219,6 +267,11 @@ bool GoogleAIProvider::supportsTools() const
return true;
}
// Reports thinking-mode capability for this provider. Google AI (Gemini)
// requests can carry a thinkingConfig, so this always answers true.
bool GoogleAIProvider::supportThinking() const
{
    return true;
}
void GoogleAIProvider::cancelRequest(const LLMCore::RequestID &requestId)
{
LOG_MESSAGE(QString("GoogleAIProvider: Cancelling request %1").arg(requestId));
@ -277,8 +330,18 @@ void GoogleAIProvider::onRequestFinished(
return;
}
if (m_failedRequests.contains(requestId)) {
cleanupRequest(requestId);
return;
}
emitPendingThinkingBlocks(requestId);
if (m_messages.contains(requestId)) {
GoogleMessage *message = m_messages[requestId];
handleMessageComplete(requestId);
if (message->state() == LLMCore::MessageState::RequiresToolExecution) {
LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId));
m_dataBuffers.remove(requestId);
@ -289,9 +352,12 @@ void GoogleAIProvider::onRequestFinished(
if (m_dataBuffers.contains(requestId)) {
const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
if (!buffers.responseContent.isEmpty()) {
LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId));
emit fullResponseReceived(requestId, buffers.responseContent);
} else {
emit fullResponseReceived(requestId, QString());
}
} else {
emit fullResponseReceived(requestId, QString());
}
cleanupRequest(requestId);
@ -306,8 +372,6 @@ void GoogleAIProvider::onToolExecutionComplete(
return;
}
LOG_MESSAGE(QString("Tool execution complete for Google AI request %1").arg(requestId));
for (auto it = toolResults.begin(); it != toolResults.end(); ++it) {
GoogleMessage *message = m_messages[requestId];
auto toolContent = message->getCurrentToolUseContent();
@ -334,10 +398,6 @@ void GoogleAIProvider::onToolExecutionComplete(
continuationRequest["contents"] = contents;
LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results")
.arg(requestId)
.arg(toolResults.size()));
sendRequest(requestId, m_requestUrls[requestId], continuationRequest);
}
@ -361,6 +421,7 @@ void GoogleAIProvider::processStreamChunk(const QString &requestId, const QJsonO
m_dataBuffers.contains(requestId)
&& message->state() == LLMCore::MessageState::RequiresToolExecution) {
message->startNewContinuation();
m_emittedThinkingBlocksCount[requestId] = 0;
LOG_MESSAGE(QString("Cleared message state for continuation request %1").arg(requestId));
}
@ -377,12 +438,34 @@ void GoogleAIProvider::processStreamChunk(const QString &requestId, const QJsonO
if (partObj.contains("text")) {
QString text = partObj["text"].toString();
message->handleContentDelta(text);
bool isThought = partObj.value("thought").toBool(false);
if (isThought) {
message->handleThoughtDelta(text);
if (partObj.contains("signature")) {
QString signature = partObj["signature"].toString();
message->handleThoughtSignature(signature);
}
} else {
emitPendingThinkingBlocks(requestId);
message->handleContentDelta(text);
LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
buffers.responseContent += text;
emit partialResponseReceived(requestId, text);
} else if (partObj.contains("functionCall")) {
LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
buffers.responseContent += text;
emit partialResponseReceived(requestId, text);
}
}
if (partObj.contains("thoughtSignature")) {
QString signature = partObj["thoughtSignature"].toString();
message->handleThoughtSignature(signature);
}
if (partObj.contains("functionCall")) {
emitPendingThinkingBlocks(requestId);
QJsonObject functionCall = partObj["functionCall"].toObject();
QString name = functionCall["name"].toString();
QJsonObject args = functionCall["args"].toObject();
@ -399,9 +482,55 @@ void GoogleAIProvider::processStreamChunk(const QString &requestId, const QJsonO
if (candidateObj.contains("finishReason")) {
QString finishReason = candidateObj["finishReason"].toString();
message->handleFinishReason(finishReason);
handleMessageComplete(requestId);
if (message->isErrorFinishReason()) {
QString errorMessage = message->getErrorMessage();
LOG_MESSAGE(QString("Google AI error: %1").arg(errorMessage));
m_failedRequests.insert(requestId);
emit requestFailed(requestId, errorMessage);
return;
}
}
}
if (chunk.contains("usageMetadata")) {
QJsonObject usageMetadata = chunk["usageMetadata"].toObject();
int thoughtsTokenCount = usageMetadata.value("thoughtsTokenCount").toInt(0);
int candidatesTokenCount = usageMetadata.value("candidatesTokenCount").toInt(0);
int totalTokenCount = usageMetadata.value("totalTokenCount").toInt(0);
if (totalTokenCount > 0) {
LOG_MESSAGE(QString("Google AI tokens: %1 (thoughts: %2, output: %3)")
.arg(totalTokenCount)
.arg(thoughtsTokenCount)
.arg(candidatesTokenCount));
}
}
}
// Forwards any thinking blocks accumulated on the message for @p requestId
// that have not yet been sent via thinkingBlockReceived(), then records the
// new high-water mark so repeated calls do not re-emit the same blocks.
void GoogleAIProvider::emitPendingThinkingBlocks(const QString &requestId)
{
    if (!m_messages.contains(requestId))
        return;

    auto blocks = m_messages[requestId]->getCurrentThinkingContent();
    if (blocks.isEmpty())
        return;

    // Skip blocks that were already emitted on an earlier call.
    const int firstUnemitted = m_emittedThinkingBlocksCount.value(requestId, 0);
    for (int idx = firstUnemitted; idx < blocks.size(); ++idx) {
        auto *block = blocks[idx];
        // Each block carries the thought text plus its provider signature.
        emit thinkingBlockReceived(requestId, block->thinking(), block->signature());
    }
    m_emittedThinkingBlocksCount[requestId] = blocks.size();
}
void GoogleAIProvider::handleMessageComplete(const QString &requestId)
@ -445,6 +574,8 @@ void GoogleAIProvider::cleanupRequest(const LLMCore::RequestID &requestId)
m_dataBuffers.remove(requestId);
m_requestUrls.remove(requestId);
m_originalRequests.remove(requestId);
m_emittedThinkingBlocksCount.remove(requestId);
m_failedRequests.remove(requestId);
m_toolsManager->cleanupRequest(requestId);
}

View File

@ -52,6 +52,7 @@ public:
const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) override;
bool supportsTools() const override;
bool supportThinking() const override;
void cancelRequest(const LLMCore::RequestID &requestId) override;
public slots:
@ -69,11 +70,14 @@ private slots:
private:
void processStreamChunk(const QString &requestId, const QJsonObject &chunk);
void handleMessageComplete(const QString &requestId);
void emitPendingThinkingBlocks(const QString &requestId);
void cleanupRequest(const LLMCore::RequestID &requestId);
QHash<LLMCore::RequestID, GoogleMessage *> m_messages;
QHash<LLMCore::RequestID, QUrl> m_requestUrls;
QHash<LLMCore::RequestID, QJsonObject> m_originalRequests;
QHash<LLMCore::RequestID, int> m_emittedThinkingBlocksCount;
QSet<LLMCore::RequestID> m_failedRequests;
Tools::ToolsManager *m_toolsManager;
};

View File

@ -43,12 +43,38 @@ void GoogleMessage::handleContentDelta(const QString &text)
}
}
// Accumulates a streamed "thought" text delta. Appends to the current
// thinking block when the last block is one; otherwise opens a fresh
// LLMCore::ThinkingContent block parented to this message.
void GoogleMessage::handleThoughtDelta(const QString &text)
{
    auto *tail = m_currentBlocks.isEmpty()
                     ? nullptr
                     : qobject_cast<LLMCore::ThinkingContent *>(m_currentBlocks.last());
    if (!tail) {
        // Start a new thinking block; ownership goes to this message.
        tail = new LLMCore::ThinkingContent();
        tail->setParent(this);
        m_currentBlocks.append(tail);
    }
    tail->appendThinking(text);
}
void GoogleMessage::handleThoughtSignature(const QString &signature)
{
for (int i = m_currentBlocks.size() - 1; i >= 0; --i) {
if (auto thinkingContent = qobject_cast<LLMCore::ThinkingContent *>(m_currentBlocks[i])) {
thinkingContent->setSignature(signature);
return;
}
}
auto thinkingContent = new LLMCore::ThinkingContent();
thinkingContent->setParent(this);
thinkingContent->setSignature(signature);
m_currentBlocks.append(thinkingContent);
}
// Begins buffering a streamed function call: resets the argument accumulator
// and remembers the function name for the JSON fragments that follow.
void GoogleMessage::handleFunctionCallStart(const QString &name)
{
    m_pendingFunctionArgs.clear();
    m_currentFunctionName = name;
    LOG_MESSAGE(QString("Google: Starting function call: %1").arg(name));
}
void GoogleMessage::handleFunctionCallArgsDelta(const QString &argsJson)
@ -75,10 +101,6 @@ void GoogleMessage::handleFunctionCallComplete()
toolContent->setParent(this);
m_currentBlocks.append(toolContent);
LOG_MESSAGE(QString("Google: Completed function call: name=%1, args=%2")
.arg(m_currentFunctionName)
.arg(QString::fromUtf8(QJsonDocument(args).toJson(QJsonDocument::Compact))));
m_currentFunctionName.clear();
m_pendingFunctionArgs.clear();
}
@ -87,9 +109,6 @@ void GoogleMessage::handleFinishReason(const QString &reason)
{
m_finishReason = reason;
updateStateFromFinishReason();
LOG_MESSAGE(
QString("Google: Finish reason: %1, state: %2").arg(reason).arg(static_cast<int>(m_state)));
}
QJsonObject GoogleMessage::toProviderFormat() const
@ -110,6 +129,19 @@ QJsonObject GoogleMessage::toProviderFormat() const
functionCall["name"] = tool->name();
functionCall["args"] = tool->input();
parts.append(QJsonObject{{"functionCall", functionCall}});
} else if (auto thinking = qobject_cast<LLMCore::ThinkingContent *>(block)) {
// Include thinking blocks with their text
QJsonObject thinkingPart;
thinkingPart["text"] = thinking->thinking();
thinkingPart["thought"] = true;
parts.append(thinkingPart);
// If there's a signature, add it as a separate part
if (!thinking->signature().isEmpty()) {
QJsonObject signaturePart;
signaturePart["thoughtSignature"] = thinking->signature();
parts.append(signaturePart);
}
}
}
@ -148,6 +180,17 @@ QList<LLMCore::ToolUseContent *> GoogleMessage::getCurrentToolUseContent() const
return toolBlocks;
}
// Returns every thinking block collected so far for the in-progress message,
// preserving the order in which the blocks were created.
QList<LLMCore::ThinkingContent *> GoogleMessage::getCurrentThinkingContent() const
{
    QList<LLMCore::ThinkingContent *> result;
    for (auto *candidate : m_currentBlocks) {
        auto *thinking = qobject_cast<LLMCore::ThinkingContent *>(candidate);
        if (thinking != nullptr)
            result.append(thinking);
    }
    return result;
}
void GoogleMessage::startNewContinuation()
{
LOG_MESSAGE(QString("GoogleMessage: Starting new continuation"));
@ -159,6 +202,34 @@ void GoogleMessage::startNewContinuation()
m_state = LLMCore::MessageState::Building;
}
bool GoogleMessage::isErrorFinishReason() const
{
return m_finishReason == "SAFETY"
|| m_finishReason == "RECITATION"
|| m_finishReason == "MALFORMED_FUNCTION_CALL"
|| m_finishReason == "PROHIBITED_CONTENT"
|| m_finishReason == "SPII"
|| m_finishReason == "OTHER";
}
// Translates an error finish reason into a user-facing description.
// Returns an empty string for reasons that do not represent an error.
QString GoogleMessage::getErrorMessage() const
{
    if (m_finishReason == "SAFETY")
        return "Response blocked by safety filters";
    if (m_finishReason == "RECITATION")
        return "Response blocked due to recitation of copyrighted content";
    if (m_finishReason == "MALFORMED_FUNCTION_CALL")
        return "Model attempted to call a function with malformed arguments. Please try rephrasing your request or disabling tools.";
    if (m_finishReason == "PROHIBITED_CONTENT")
        return "Response blocked due to prohibited content";
    if (m_finishReason == "SPII")
        return "Response blocked due to sensitive personally identifiable information";
    if (m_finishReason == "OTHER")
        return "Request failed due to an unknown reason";
    return {};
}
void GoogleMessage::updateStateFromFinishReason()
{
if (m_finishReason == "STOP" || m_finishReason == "MAX_TOKENS") {

View File

@ -35,6 +35,8 @@ public:
explicit GoogleMessage(QObject *parent = nullptr);
void handleContentDelta(const QString &text);
void handleThoughtDelta(const QString &text);
void handleThoughtSignature(const QString &signature);
void handleFunctionCallStart(const QString &name);
void handleFunctionCallArgsDelta(const QString &argsJson);
void handleFunctionCallComplete();
@ -44,9 +46,13 @@ public:
QJsonArray createToolResultParts(const QHash<QString, QString> &toolResults) const;
QList<LLMCore::ToolUseContent *> getCurrentToolUseContent() const;
QList<LLMCore::ThinkingContent *> getCurrentThinkingContent() const;
QList<LLMCore::ContentBlock *> currentBlocks() const { return m_currentBlocks; }
LLMCore::MessageState state() const { return m_state; }
QString finishReason() const { return m_finishReason; }
bool isErrorFinishReason() const;
QString getErrorMessage() const;
void startNewContinuation();
private: