diff --git a/providers/MistralAIProvider.cpp b/providers/MistralAIProvider.cpp
index 8d669d1..99318de 100644
--- a/providers/MistralAIProvider.cpp
+++ b/providers/MistralAIProvider.cpp
@@ -1,21 +1,49 @@
+/*
+ * Copyright (C) 2024-2025 Petr Mironychev
+ *
+ * This file is part of QodeAssist.
+ *
+ * QodeAssist is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * QodeAssist is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with QodeAssist. If not, see .
+ */
+
#include "MistralAIProvider.hpp"
+#include "llmcore/ValidationUtils.hpp"
+#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
#include "settings/ProviderSettings.hpp"
+#include
#include
#include
#include
#include
-#include
-
-#include "llmcore/OpenAIMessage.hpp"
-#include "llmcore/ValidationUtils.hpp"
-#include "logger/Logger.hpp"
namespace QodeAssist::Providers {
+MistralAIProvider::MistralAIProvider(QObject *parent)
+ : LLMCore::Provider(parent)
+ , m_toolsManager(new Tools::ToolsManager(this))
+{
+ connect(
+ m_toolsManager,
+ &Tools::ToolsManager::toolExecutionComplete,
+ this,
+ &MistralAIProvider::onToolExecutionComplete);
+}
+
QString MistralAIProvider::name() const
{
return "Mistral AI";
@@ -97,10 +125,12 @@ QList MistralAIProvider::validateRequest(
{"temperature", {}},
{"max_tokens", {}},
{"top_p", {}},
+ {"top_k", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}},
{"stop", QJsonArray{}},
- {"stream", {}}};
+ {"stream", {}},
+ {"tools", {}}};
return LLMCore::ValidationUtils::validateRequestFields(
request, type == LLMCore::TemplateType::FIM ? fimReq : templateReq);
@@ -128,8 +158,12 @@ LLMCore::ProviderID MistralAIProvider::providerID() const
void MistralAIProvider::sendRequest(
const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload)
{
- m_dataBuffers[requestId].clear();
+ if (!m_messages.contains(requestId)) {
+ m_dataBuffers[requestId].clear();
+ }
+
m_requestUrls[requestId] = url;
+ m_originalRequests[requestId] = payload;
QNetworkRequest networkRequest(url);
prepareNetworkRequest(networkRequest);
@@ -143,57 +177,34 @@ void MistralAIProvider::sendRequest(
emit httpClient()->sendRequest(request);
}
+bool MistralAIProvider::supportsTools() const
+{
+ return true;
+}
+
+void MistralAIProvider::cancelRequest(const LLMCore::RequestID &requestId)
+{
+ LOG_MESSAGE(QString("MistralAIProvider: Cancelling request %1").arg(requestId));
+ LLMCore::Provider::cancelRequest(requestId);
+ cleanupRequest(requestId);
+}
+
void MistralAIProvider::onDataReceived(
const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data)
{
LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
QStringList lines = buffers.rawStreamBuffer.processData(data);
- if (data.isEmpty()) {
- return;
- }
-
- bool isDone = false;
- QString tempResponse;
-
for (const QString &line : lines) {
- if (line.trimmed().isEmpty()) {
+ if (line.trimmed().isEmpty() || line == "data: [DONE]") {
continue;
}
- if (line == "data: [DONE]") {
- isDone = true;
- continue;
- }
-
- QJsonObject responseObj = parseEventLine(line);
- if (responseObj.isEmpty())
+ QJsonObject chunk = parseEventLine(line);
+ if (chunk.isEmpty())
continue;
- auto message = LLMCore::OpenAIMessage::fromJson(responseObj);
- if (message.hasError()) {
- LOG_MESSAGE("Error in MistralAI response: " + message.error);
- continue;
- }
-
- QString content = message.getContent();
- if (!content.isEmpty()) {
- tempResponse += content;
- }
-
- if (message.isDone()) {
- isDone = true;
- }
- }
-
- if (!tempResponse.isEmpty()) {
- buffers.responseContent += tempResponse;
- emit partialResponseReceived(requestId, tempResponse);
- }
-
- if (isDone) {
- emit fullResponseReceived(requestId, buffers.responseContent);
- m_dataBuffers.remove(requestId);
+ processStreamChunk(requestId, chunk);
}
}
@@ -203,17 +214,28 @@ void MistralAIProvider::onRequestFinished(
if (!success) {
LOG_MESSAGE(QString("MistralAIProvider request %1 failed: %2").arg(requestId, error));
emit requestFailed(requestId, error);
- } else {
- if (m_dataBuffers.contains(requestId)) {
- const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
- if (!buffers.responseContent.isEmpty()) {
- emit fullResponseReceived(requestId, buffers.responseContent);
- }
+ cleanupRequest(requestId);
+ return;
+ }
+
+ if (m_messages.contains(requestId)) {
+ OpenAIMessage *message = m_messages[requestId];
+ if (message->state() == LLMCore::MessageState::RequiresToolExecution) {
+ LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId));
+ m_dataBuffers.remove(requestId);
+ return;
}
}
- m_dataBuffers.remove(requestId);
- m_requestUrls.remove(requestId);
+ if (m_dataBuffers.contains(requestId)) {
+ const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
+ if (!buffers.responseContent.isEmpty()) {
+ LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId));
+ emit fullResponseReceived(requestId, buffers.responseContent);
+ }
+ }
+
+ cleanupRequest(requestId);
}
void MistralAIProvider::prepareRequest(
@@ -228,33 +250,167 @@ void MistralAIProvider::prepareRequest(
prompt->prepareRequest(request, context);
- if (type == LLMCore::RequestType::Chat) {
- auto &settings = Settings::chatAssistantSettings();
-
+ auto applyModelParams = [&request](const auto &settings) {
request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature();
if (settings.useTopP())
request["top_p"] = settings.topP();
-
- // request["random_seed"] = "";
-
+ if (settings.useTopK())
+ request["top_k"] = settings.topK();
if (settings.useFrequencyPenalty())
request["frequency_penalty"] = settings.frequencyPenalty();
if (settings.usePresencePenalty())
request["presence_penalty"] = settings.presencePenalty();
+ };
+ if (type == LLMCore::RequestType::CodeCompletion) {
+ applyModelParams(Settings::codeCompletionSettings());
} else {
- auto &settings = Settings::codeCompletionSettings();
+ applyModelParams(Settings::chatAssistantSettings());
+ }
- request["max_tokens"] = settings.maxTokens();
- request["temperature"] = settings.temperature();
-
- if (settings.useTopP())
- request["top_p"] = settings.topP();
-
- // request["random_seed"] = "";
+ if (supportsTools() && type == LLMCore::RequestType::Chat
+ && Settings::chatAssistantSettings().useTools()) {
+ auto toolsDefinitions = m_toolsManager->getToolsDefinitions(Tools::ToolSchemaFormat::OpenAI);
+ if (!toolsDefinitions.isEmpty()) {
+ request["tools"] = toolsDefinitions;
+ LOG_MESSAGE(QString("Added %1 tools to Mistral request").arg(toolsDefinitions.size()));
+ }
}
}
+void MistralAIProvider::onToolExecutionComplete(
+ const QString &requestId, const QHash &toolResults)
+{
+ if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) {
+ LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId));
+ cleanupRequest(requestId);
+ return;
+ }
+
+ LOG_MESSAGE(QString("Tool execution complete for Mistral request %1").arg(requestId));
+
+ OpenAIMessage *message = m_messages[requestId];
+ QJsonObject continuationRequest = m_originalRequests[requestId];
+ QJsonArray messages = continuationRequest["messages"].toArray();
+
+ messages.append(message->toProviderFormat());
+
+ QJsonArray toolResultMessages = message->createToolResultMessages(toolResults);
+ for (const auto &toolMsg : toolResultMessages) {
+ messages.append(toolMsg);
+ }
+
+ continuationRequest["messages"] = messages;
+
+ LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results")
+ .arg(requestId)
+ .arg(toolResults.size()));
+
+ sendRequest(requestId, m_requestUrls[requestId], continuationRequest);
+}
+
+void MistralAIProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk)
+{
+ QJsonArray choices = chunk["choices"].toArray();
+ if (choices.isEmpty()) {
+ return;
+ }
+
+ QJsonObject choice = choices[0].toObject();
+ QJsonObject delta = choice["delta"].toObject();
+ QString finishReason = choice["finish_reason"].toString();
+
+ OpenAIMessage *message = m_messages.value(requestId);
+ if (!message) {
+ message = new OpenAIMessage(this);
+ m_messages[requestId] = message;
+ LOG_MESSAGE(QString("Created NEW OpenAIMessage for Mistral request %1").arg(requestId));
+ }
+
+ if (delta.contains("content") && !delta["content"].isNull()) {
+ QString content = delta["content"].toString();
+ message->handleContentDelta(content);
+
+ LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
+ buffers.responseContent += content;
+ emit partialResponseReceived(requestId, content);
+ }
+
+ if (delta.contains("tool_calls")) {
+ QJsonArray toolCalls = delta["tool_calls"].toArray();
+ for (const auto &toolCallValue : toolCalls) {
+ QJsonObject toolCall = toolCallValue.toObject();
+ int index = toolCall["index"].toInt();
+
+ if (toolCall.contains("id")) {
+ QString id = toolCall["id"].toString();
+ QJsonObject function = toolCall["function"].toObject();
+ QString name = function["name"].toString();
+ message->handleToolCallStart(index, id, name);
+ }
+
+ if (toolCall.contains("function")) {
+ QJsonObject function = toolCall["function"].toObject();
+ if (function.contains("arguments")) {
+ QString args = function["arguments"].toString();
+ message->handleToolCallDelta(index, args);
+ }
+ }
+ }
+ }
+
+ if (!finishReason.isEmpty() && finishReason != "null") {
+ for (int i = 0; i < 10; ++i) {
+ message->handleToolCallComplete(i);
+ }
+
+ message->handleFinishReason(finishReason);
+ handleMessageComplete(requestId);
+ }
+}
+
+void MistralAIProvider::handleMessageComplete(const QString &requestId)
+{
+ if (!m_messages.contains(requestId))
+ return;
+
+ OpenAIMessage *message = m_messages[requestId];
+
+ if (message->state() == LLMCore::MessageState::RequiresToolExecution) {
+ LOG_MESSAGE(QString("Mistral message requires tool execution for %1").arg(requestId));
+
+ auto toolUseContent = message->getCurrentToolUseContent();
+
+ if (toolUseContent.isEmpty()) {
+ LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId));
+ return;
+ }
+
+ for (auto toolContent : toolUseContent) {
+ m_toolsManager->executeToolCall(
+ requestId, toolContent->id(), toolContent->name(), toolContent->input());
+ }
+
+ } else {
+ LOG_MESSAGE(QString("Mistral message marked as complete for %1").arg(requestId));
+ }
+}
+
+void MistralAIProvider::cleanupRequest(const LLMCore::RequestID &requestId)
+{
+ LOG_MESSAGE(QString("Cleaning up Mistral request %1").arg(requestId));
+
+ if (m_messages.contains(requestId)) {
+ OpenAIMessage *message = m_messages.take(requestId);
+ message->deleteLater();
+ }
+
+ m_dataBuffers.remove(requestId);
+ m_requestUrls.remove(requestId);
+ m_originalRequests.remove(requestId);
+ m_toolsManager->cleanupRequest(requestId);
+}
+
} // namespace QodeAssist::Providers
diff --git a/providers/MistralAIProvider.hpp b/providers/MistralAIProvider.hpp
index 7ce86d2..34fed7a 100644
--- a/providers/MistralAIProvider.hpp
+++ b/providers/MistralAIProvider.hpp
@@ -19,13 +19,18 @@
#pragma once
-#include "llmcore/Provider.hpp"
+#include "OpenAIMessage.hpp"
+#include "tools/ToolsManager.hpp"
+#include
namespace QodeAssist::Providers {
class MistralAIProvider : public LLMCore::Provider
{
+ Q_OBJECT
public:
+ explicit MistralAIProvider(QObject *parent = nullptr);
+
QString name() const override;
QString url() const override;
QString completionEndpoint() const override;
@@ -45,6 +50,9 @@ public:
void sendRequest(
const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) override;
+ bool supportsTools() const override;
+ void cancelRequest(const LLMCore::RequestID &requestId) override;
+
public slots:
void onDataReceived(
const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data) override;
@@ -52,6 +60,20 @@ public slots:
const QodeAssist::LLMCore::RequestID &requestId,
bool success,
const QString &error) override;
+
+private slots:
+ void onToolExecutionComplete(
+ const QString &requestId, const QHash &toolResults);
+
+private:
+ void processStreamChunk(const QString &requestId, const QJsonObject &chunk);
+ void handleMessageComplete(const QString &requestId);
+ void cleanupRequest(const LLMCore::RequestID &requestId);
+
+ QHash m_messages;
+ QHash m_requestUrls;
+ QHash m_originalRequests;
+ Tools::ToolsManager *m_toolsManager;
};
} // namespace QodeAssist::Providers