diff --git a/providers/MistralAIProvider.cpp b/providers/MistralAIProvider.cpp index 8d669d1..99318de 100644 --- a/providers/MistralAIProvider.cpp +++ b/providers/MistralAIProvider.cpp @@ -1,21 +1,49 @@ +/* + * Copyright (C) 2024-2025 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see <https://www.gnu.org/licenses/>. + */ + #include "MistralAIProvider.hpp" +#include "llmcore/ValidationUtils.hpp" +#include "logger/Logger.hpp" #include "settings/ChatAssistantSettings.hpp" #include "settings/CodeCompletionSettings.hpp" #include "settings/ProviderSettings.hpp" +#include #include #include #include #include -#include - -#include "llmcore/OpenAIMessage.hpp" -#include "llmcore/ValidationUtils.hpp" -#include "logger/Logger.hpp" namespace QodeAssist::Providers { +MistralAIProvider::MistralAIProvider(QObject *parent) + : LLMCore::Provider(parent) + , m_toolsManager(new Tools::ToolsManager(this)) +{ + connect( + m_toolsManager, + &Tools::ToolsManager::toolExecutionComplete, + this, + &MistralAIProvider::onToolExecutionComplete); +} + QString MistralAIProvider::name() const { return "Mistral AI"; @@ -97,10 +125,12 @@ QList MistralAIProvider::validateRequest( {"temperature", {}}, {"max_tokens", {}}, {"top_p", {}}, + {"top_k", {}}, {"frequency_penalty", {}}, {"presence_penalty", {}}, {"stop", QJsonArray{}}, - {"stream", {}}}; + {"stream", {}}, + {"tools", {}}}; return 
LLMCore::ValidationUtils::validateRequestFields( request, type == LLMCore::TemplateType::FIM ? fimReq : templateReq); @@ -128,8 +158,12 @@ LLMCore::ProviderID MistralAIProvider::providerID() const void MistralAIProvider::sendRequest( const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) { - m_dataBuffers[requestId].clear(); + if (!m_messages.contains(requestId)) { + m_dataBuffers[requestId].clear(); + } + m_requestUrls[requestId] = url; + m_originalRequests[requestId] = payload; QNetworkRequest networkRequest(url); prepareNetworkRequest(networkRequest); @@ -143,57 +177,34 @@ void MistralAIProvider::sendRequest( emit httpClient()->sendRequest(request); } +bool MistralAIProvider::supportsTools() const +{ + return true; +} + +void MistralAIProvider::cancelRequest(const LLMCore::RequestID &requestId) +{ + LOG_MESSAGE(QString("MistralAIProvider: Cancelling request %1").arg(requestId)); + LLMCore::Provider::cancelRequest(requestId); + cleanupRequest(requestId); +} + void MistralAIProvider::onDataReceived( const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data) { LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; QStringList lines = buffers.rawStreamBuffer.processData(data); - if (data.isEmpty()) { - return; - } - - bool isDone = false; - QString tempResponse; - for (const QString &line : lines) { - if (line.trimmed().isEmpty()) { + if (line.trimmed().isEmpty() || line == "data: [DONE]") { continue; } - if (line == "data: [DONE]") { - isDone = true; - continue; - } - - QJsonObject responseObj = parseEventLine(line); - if (responseObj.isEmpty()) + QJsonObject chunk = parseEventLine(line); + if (chunk.isEmpty()) continue; - auto message = LLMCore::OpenAIMessage::fromJson(responseObj); - if (message.hasError()) { - LOG_MESSAGE("Error in MistralAI response: " + message.error); - continue; - } - - QString content = message.getContent(); - if (!content.isEmpty()) { - tempResponse += content; - } - - if (message.isDone()) { - 
isDone = true; - } - } - - if (!tempResponse.isEmpty()) { - buffers.responseContent += tempResponse; - emit partialResponseReceived(requestId, tempResponse); - } - - if (isDone) { - emit fullResponseReceived(requestId, buffers.responseContent); - m_dataBuffers.remove(requestId); + processStreamChunk(requestId, chunk); } } @@ -203,17 +214,28 @@ void MistralAIProvider::onRequestFinished( if (!success) { LOG_MESSAGE(QString("MistralAIProvider request %1 failed: %2").arg(requestId, error)); emit requestFailed(requestId, error); - } else { - if (m_dataBuffers.contains(requestId)) { - const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - emit fullResponseReceived(requestId, buffers.responseContent); - } + cleanupRequest(requestId); + return; + } + + if (m_messages.contains(requestId)) { + OpenAIMessage *message = m_messages[requestId]; + if (message->state() == LLMCore::MessageState::RequiresToolExecution) { + LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); + m_dataBuffers.remove(requestId); + return; } } - m_dataBuffers.remove(requestId); - m_requestUrls.remove(requestId); + if (m_dataBuffers.contains(requestId)) { + const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; + if (!buffers.responseContent.isEmpty()) { + LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); + emit fullResponseReceived(requestId, buffers.responseContent); + } + } + + cleanupRequest(requestId); } void MistralAIProvider::prepareRequest( @@ -228,33 +250,167 @@ void MistralAIProvider::prepareRequest( prompt->prepareRequest(request, context); - if (type == LLMCore::RequestType::Chat) { - auto &settings = Settings::chatAssistantSettings(); - + auto applyModelParams = [&request](const auto &settings) { request["max_tokens"] = settings.maxTokens(); request["temperature"] = settings.temperature(); if (settings.useTopP()) request["top_p"] = settings.topP(); - - // request["random_seed"] = ""; - 
+ if (settings.useTopK()) + request["top_k"] = settings.topK(); if (settings.useFrequencyPenalty()) request["frequency_penalty"] = settings.frequencyPenalty(); if (settings.usePresencePenalty()) request["presence_penalty"] = settings.presencePenalty(); + }; + if (type == LLMCore::RequestType::CodeCompletion) { + applyModelParams(Settings::codeCompletionSettings()); } else { - auto &settings = Settings::codeCompletionSettings(); + applyModelParams(Settings::chatAssistantSettings()); + } - request["max_tokens"] = settings.maxTokens(); - request["temperature"] = settings.temperature(); - - if (settings.useTopP()) - request["top_p"] = settings.topP(); - - // request["random_seed"] = ""; + if (supportsTools() && type == LLMCore::RequestType::Chat + && Settings::chatAssistantSettings().useTools()) { + auto toolsDefinitions = m_toolsManager->getToolsDefinitions(Tools::ToolSchemaFormat::OpenAI); + if (!toolsDefinitions.isEmpty()) { + request["tools"] = toolsDefinitions; + LOG_MESSAGE(QString("Added %1 tools to Mistral request").arg(toolsDefinitions.size())); + } } } +void MistralAIProvider::onToolExecutionComplete( + const QString &requestId, const QHash &toolResults) +{ + if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { + LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); + cleanupRequest(requestId); + return; + } + + LOG_MESSAGE(QString("Tool execution complete for Mistral request %1").arg(requestId)); + + OpenAIMessage *message = m_messages[requestId]; + QJsonObject continuationRequest = m_originalRequests[requestId]; + QJsonArray messages = continuationRequest["messages"].toArray(); + + messages.append(message->toProviderFormat()); + + QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); + for (const auto &toolMsg : toolResultMessages) { + messages.append(toolMsg); + } + + continuationRequest["messages"] = messages; + + LOG_MESSAGE(QString("Sending continuation request for %1 
with %2 tool results") + .arg(requestId) + .arg(toolResults.size())); + + sendRequest(requestId, m_requestUrls[requestId], continuationRequest); +} + +void MistralAIProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) +{ + QJsonArray choices = chunk["choices"].toArray(); + if (choices.isEmpty()) { + return; + } + + QJsonObject choice = choices[0].toObject(); + QJsonObject delta = choice["delta"].toObject(); + QString finishReason = choice["finish_reason"].toString(); + + OpenAIMessage *message = m_messages.value(requestId); + if (!message) { + message = new OpenAIMessage(this); + m_messages[requestId] = message; + LOG_MESSAGE(QString("Created NEW OpenAIMessage for Mistral request %1").arg(requestId)); + } + + if (delta.contains("content") && !delta["content"].isNull()) { + QString content = delta["content"].toString(); + message->handleContentDelta(content); + + LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; + buffers.responseContent += content; + emit partialResponseReceived(requestId, content); + } + + if (delta.contains("tool_calls")) { + QJsonArray toolCalls = delta["tool_calls"].toArray(); + for (const auto &toolCallValue : toolCalls) { + QJsonObject toolCall = toolCallValue.toObject(); + int index = toolCall["index"].toInt(); + + if (toolCall.contains("id")) { + QString id = toolCall["id"].toString(); + QJsonObject function = toolCall["function"].toObject(); + QString name = function["name"].toString(); + message->handleToolCallStart(index, id, name); + } + + if (toolCall.contains("function")) { + QJsonObject function = toolCall["function"].toObject(); + if (function.contains("arguments")) { + QString args = function["arguments"].toString(); + message->handleToolCallDelta(index, args); + } + } + } + } + + if (!finishReason.isEmpty() && finishReason != "null") { + for (int i = 0; i < 10; ++i) { + message->handleToolCallComplete(i); + } + + message->handleFinishReason(finishReason); + handleMessageComplete(requestId); + } 
+} + +void MistralAIProvider::handleMessageComplete(const QString &requestId) +{ + if (!m_messages.contains(requestId)) + return; + + OpenAIMessage *message = m_messages[requestId]; + + if (message->state() == LLMCore::MessageState::RequiresToolExecution) { + LOG_MESSAGE(QString("Mistral message requires tool execution for %1").arg(requestId)); + + auto toolUseContent = message->getCurrentToolUseContent(); + + if (toolUseContent.isEmpty()) { + LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId)); + return; + } + + for (auto toolContent : toolUseContent) { + m_toolsManager->executeToolCall( + requestId, toolContent->id(), toolContent->name(), toolContent->input()); + } + + } else { + LOG_MESSAGE(QString("Mistral message marked as complete for %1").arg(requestId)); + } +} + +void MistralAIProvider::cleanupRequest(const LLMCore::RequestID &requestId) +{ + LOG_MESSAGE(QString("Cleaning up Mistral request %1").arg(requestId)); + + if (m_messages.contains(requestId)) { + OpenAIMessage *message = m_messages.take(requestId); + message->deleteLater(); + } + + m_dataBuffers.remove(requestId); + m_requestUrls.remove(requestId); + m_originalRequests.remove(requestId); + m_toolsManager->cleanupRequest(requestId); +} + } // namespace QodeAssist::Providers diff --git a/providers/MistralAIProvider.hpp b/providers/MistralAIProvider.hpp index 7ce86d2..34fed7a 100644 --- a/providers/MistralAIProvider.hpp +++ b/providers/MistralAIProvider.hpp @@ -19,13 +19,18 @@ #pragma once -#include "llmcore/Provider.hpp" +#include "OpenAIMessage.hpp" +#include "tools/ToolsManager.hpp" +#include namespace QodeAssist::Providers { class MistralAIProvider : public LLMCore::Provider { + Q_OBJECT public: + explicit MistralAIProvider(QObject *parent = nullptr); + QString name() const override; QString url() const override; QString completionEndpoint() const override; @@ -45,6 +50,9 @@ public: void sendRequest( const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) 
override; + bool supportsTools() const override; + void cancelRequest(const LLMCore::RequestID &requestId) override; + public slots: void onDataReceived( const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data) override; @@ -52,6 +60,20 @@ public slots: const QodeAssist::LLMCore::RequestID &requestId, bool success, const QString &error) override; + +private slots: + void onToolExecutionComplete( + const QString &requestId, const QHash &toolResults); + +private: + void processStreamChunk(const QString &requestId, const QJsonObject &chunk); + void handleMessageComplete(const QString &requestId); + void cleanupRequest(const LLMCore::RequestID &requestId); + + QHash m_messages; + QHash m_requestUrls; + QHash m_originalRequests; + Tools::ToolsManager *m_toolsManager; }; } // namespace QodeAssist::Providers