feat: Add Mistral AI tooling support

This commit is contained in:
Petr Mironychev
2025-10-01 15:58:45 +02:00
parent ea4f8b9df9
commit 1a08eebe92
2 changed files with 247 additions and 69 deletions

View File

@ -1,21 +1,49 @@
/*
* Copyright (C) 2024-2025 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#include "MistralAIProvider.hpp" #include "MistralAIProvider.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp" #include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp" #include "settings/CodeCompletionSettings.hpp"
#include "settings/ProviderSettings.hpp" #include "settings/ProviderSettings.hpp"
#include <QEventLoop>
#include <QJsonArray> #include <QJsonArray>
#include <QJsonDocument> #include <QJsonDocument>
#include <QJsonObject> #include <QJsonObject>
#include <QNetworkReply> #include <QNetworkReply>
#include <QtCore/qeventloop.h>
#include "llmcore/OpenAIMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
namespace QodeAssist::Providers { namespace QodeAssist::Providers {
MistralAIProvider::MistralAIProvider(QObject *parent)
: LLMCore::Provider(parent)
, m_toolsManager(new Tools::ToolsManager(this))
{
connect(
m_toolsManager,
&Tools::ToolsManager::toolExecutionComplete,
this,
&MistralAIProvider::onToolExecutionComplete);
}
QString MistralAIProvider::name() const QString MistralAIProvider::name() const
{ {
return "Mistral AI"; return "Mistral AI";
@ -97,10 +125,12 @@ QList<QString> MistralAIProvider::validateRequest(
{"temperature", {}}, {"temperature", {}},
{"max_tokens", {}}, {"max_tokens", {}},
{"top_p", {}}, {"top_p", {}},
{"top_k", {}},
{"frequency_penalty", {}}, {"frequency_penalty", {}},
{"presence_penalty", {}}, {"presence_penalty", {}},
{"stop", QJsonArray{}}, {"stop", QJsonArray{}},
{"stream", {}}}; {"stream", {}},
{"tools", {}}};
return LLMCore::ValidationUtils::validateRequestFields( return LLMCore::ValidationUtils::validateRequestFields(
request, type == LLMCore::TemplateType::FIM ? fimReq : templateReq); request, type == LLMCore::TemplateType::FIM ? fimReq : templateReq);
@ -128,8 +158,12 @@ LLMCore::ProviderID MistralAIProvider::providerID() const
void MistralAIProvider::sendRequest( void MistralAIProvider::sendRequest(
const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload)
{ {
m_dataBuffers[requestId].clear(); if (!m_messages.contains(requestId)) {
m_dataBuffers[requestId].clear();
}
m_requestUrls[requestId] = url; m_requestUrls[requestId] = url;
m_originalRequests[requestId] = payload;
QNetworkRequest networkRequest(url); QNetworkRequest networkRequest(url);
prepareNetworkRequest(networkRequest); prepareNetworkRequest(networkRequest);
@ -143,57 +177,34 @@ void MistralAIProvider::sendRequest(
emit httpClient()->sendRequest(request); emit httpClient()->sendRequest(request);
} }
bool MistralAIProvider::supportsTools() const
{
return true;
}
void MistralAIProvider::cancelRequest(const LLMCore::RequestID &requestId)
{
LOG_MESSAGE(QString("MistralAIProvider: Cancelling request %1").arg(requestId));
LLMCore::Provider::cancelRequest(requestId);
cleanupRequest(requestId);
}
void MistralAIProvider::onDataReceived( void MistralAIProvider::onDataReceived(
const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data) const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data)
{ {
LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
QStringList lines = buffers.rawStreamBuffer.processData(data); QStringList lines = buffers.rawStreamBuffer.processData(data);
if (data.isEmpty()) {
return;
}
bool isDone = false;
QString tempResponse;
for (const QString &line : lines) { for (const QString &line : lines) {
if (line.trimmed().isEmpty()) { if (line.trimmed().isEmpty() || line == "data: [DONE]") {
continue; continue;
} }
if (line == "data: [DONE]") { QJsonObject chunk = parseEventLine(line);
isDone = true; if (chunk.isEmpty())
continue;
}
QJsonObject responseObj = parseEventLine(line);
if (responseObj.isEmpty())
continue; continue;
auto message = LLMCore::OpenAIMessage::fromJson(responseObj); processStreamChunk(requestId, chunk);
if (message.hasError()) {
LOG_MESSAGE("Error in MistralAI response: " + message.error);
continue;
}
QString content = message.getContent();
if (!content.isEmpty()) {
tempResponse += content;
}
if (message.isDone()) {
isDone = true;
}
}
if (!tempResponse.isEmpty()) {
buffers.responseContent += tempResponse;
emit partialResponseReceived(requestId, tempResponse);
}
if (isDone) {
emit fullResponseReceived(requestId, buffers.responseContent);
m_dataBuffers.remove(requestId);
} }
} }
@ -203,17 +214,28 @@ void MistralAIProvider::onRequestFinished(
if (!success) { if (!success) {
LOG_MESSAGE(QString("MistralAIProvider request %1 failed: %2").arg(requestId, error)); LOG_MESSAGE(QString("MistralAIProvider request %1 failed: %2").arg(requestId, error));
emit requestFailed(requestId, error); emit requestFailed(requestId, error);
} else { cleanupRequest(requestId);
if (m_dataBuffers.contains(requestId)) { return;
const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; }
if (!buffers.responseContent.isEmpty()) {
emit fullResponseReceived(requestId, buffers.responseContent); if (m_messages.contains(requestId)) {
} OpenAIMessage *message = m_messages[requestId];
if (message->state() == LLMCore::MessageState::RequiresToolExecution) {
LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId));
m_dataBuffers.remove(requestId);
return;
} }
} }
m_dataBuffers.remove(requestId); if (m_dataBuffers.contains(requestId)) {
m_requestUrls.remove(requestId); const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
if (!buffers.responseContent.isEmpty()) {
LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId));
emit fullResponseReceived(requestId, buffers.responseContent);
}
}
cleanupRequest(requestId);
} }
void MistralAIProvider::prepareRequest( void MistralAIProvider::prepareRequest(
@ -228,33 +250,167 @@ void MistralAIProvider::prepareRequest(
prompt->prepareRequest(request, context); prompt->prepareRequest(request, context);
if (type == LLMCore::RequestType::Chat) { auto applyModelParams = [&request](const auto &settings) {
auto &settings = Settings::chatAssistantSettings();
request["max_tokens"] = settings.maxTokens(); request["max_tokens"] = settings.maxTokens();
request["temperature"] = settings.temperature(); request["temperature"] = settings.temperature();
if (settings.useTopP()) if (settings.useTopP())
request["top_p"] = settings.topP(); request["top_p"] = settings.topP();
if (settings.useTopK())
// request["random_seed"] = ""; request["top_k"] = settings.topK();
if (settings.useFrequencyPenalty()) if (settings.useFrequencyPenalty())
request["frequency_penalty"] = settings.frequencyPenalty(); request["frequency_penalty"] = settings.frequencyPenalty();
if (settings.usePresencePenalty()) if (settings.usePresencePenalty())
request["presence_penalty"] = settings.presencePenalty(); request["presence_penalty"] = settings.presencePenalty();
};
if (type == LLMCore::RequestType::CodeCompletion) {
applyModelParams(Settings::codeCompletionSettings());
} else { } else {
auto &settings = Settings::codeCompletionSettings(); applyModelParams(Settings::chatAssistantSettings());
}
request["max_tokens"] = settings.maxTokens(); if (supportsTools() && type == LLMCore::RequestType::Chat
request["temperature"] = settings.temperature(); && Settings::chatAssistantSettings().useTools()) {
auto toolsDefinitions = m_toolsManager->getToolsDefinitions(Tools::ToolSchemaFormat::OpenAI);
if (settings.useTopP()) if (!toolsDefinitions.isEmpty()) {
request["top_p"] = settings.topP(); request["tools"] = toolsDefinitions;
LOG_MESSAGE(QString("Added %1 tools to Mistral request").arg(toolsDefinitions.size()));
// request["random_seed"] = ""; }
} }
} }
void MistralAIProvider::onToolExecutionComplete(
const QString &requestId, const QHash<QString, QString> &toolResults)
{
if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) {
LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId));
cleanupRequest(requestId);
return;
}
LOG_MESSAGE(QString("Tool execution complete for Mistral request %1").arg(requestId));
OpenAIMessage *message = m_messages[requestId];
QJsonObject continuationRequest = m_originalRequests[requestId];
QJsonArray messages = continuationRequest["messages"].toArray();
messages.append(message->toProviderFormat());
QJsonArray toolResultMessages = message->createToolResultMessages(toolResults);
for (const auto &toolMsg : toolResultMessages) {
messages.append(toolMsg);
}
continuationRequest["messages"] = messages;
LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results")
.arg(requestId)
.arg(toolResults.size()));
sendRequest(requestId, m_requestUrls[requestId], continuationRequest);
}
void MistralAIProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk)
{
QJsonArray choices = chunk["choices"].toArray();
if (choices.isEmpty()) {
return;
}
QJsonObject choice = choices[0].toObject();
QJsonObject delta = choice["delta"].toObject();
QString finishReason = choice["finish_reason"].toString();
OpenAIMessage *message = m_messages.value(requestId);
if (!message) {
message = new OpenAIMessage(this);
m_messages[requestId] = message;
LOG_MESSAGE(QString("Created NEW OpenAIMessage for Mistral request %1").arg(requestId));
}
if (delta.contains("content") && !delta["content"].isNull()) {
QString content = delta["content"].toString();
message->handleContentDelta(content);
LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
buffers.responseContent += content;
emit partialResponseReceived(requestId, content);
}
if (delta.contains("tool_calls")) {
QJsonArray toolCalls = delta["tool_calls"].toArray();
for (const auto &toolCallValue : toolCalls) {
QJsonObject toolCall = toolCallValue.toObject();
int index = toolCall["index"].toInt();
if (toolCall.contains("id")) {
QString id = toolCall["id"].toString();
QJsonObject function = toolCall["function"].toObject();
QString name = function["name"].toString();
message->handleToolCallStart(index, id, name);
}
if (toolCall.contains("function")) {
QJsonObject function = toolCall["function"].toObject();
if (function.contains("arguments")) {
QString args = function["arguments"].toString();
message->handleToolCallDelta(index, args);
}
}
}
}
if (!finishReason.isEmpty() && finishReason != "null") {
for (int i = 0; i < 10; ++i) {
message->handleToolCallComplete(i);
}
message->handleFinishReason(finishReason);
handleMessageComplete(requestId);
}
}
void MistralAIProvider::handleMessageComplete(const QString &requestId)
{
if (!m_messages.contains(requestId))
return;
OpenAIMessage *message = m_messages[requestId];
if (message->state() == LLMCore::MessageState::RequiresToolExecution) {
LOG_MESSAGE(QString("Mistral message requires tool execution for %1").arg(requestId));
auto toolUseContent = message->getCurrentToolUseContent();
if (toolUseContent.isEmpty()) {
LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId));
return;
}
for (auto toolContent : toolUseContent) {
m_toolsManager->executeToolCall(
requestId, toolContent->id(), toolContent->name(), toolContent->input());
}
} else {
LOG_MESSAGE(QString("Mistral message marked as complete for %1").arg(requestId));
}
}
void MistralAIProvider::cleanupRequest(const LLMCore::RequestID &requestId)
{
LOG_MESSAGE(QString("Cleaning up Mistral request %1").arg(requestId));
if (m_messages.contains(requestId)) {
OpenAIMessage *message = m_messages.take(requestId);
message->deleteLater();
}
m_dataBuffers.remove(requestId);
m_requestUrls.remove(requestId);
m_originalRequests.remove(requestId);
m_toolsManager->cleanupRequest(requestId);
}
} // namespace QodeAssist::Providers } // namespace QodeAssist::Providers

View File

@ -19,13 +19,18 @@
#pragma once #pragma once
#include "llmcore/Provider.hpp" #include "OpenAIMessage.hpp"
#include "tools/ToolsManager.hpp"
#include <llmcore/Provider.hpp>
namespace QodeAssist::Providers { namespace QodeAssist::Providers {
class MistralAIProvider : public LLMCore::Provider class MistralAIProvider : public LLMCore::Provider
{ {
Q_OBJECT
public: public:
explicit MistralAIProvider(QObject *parent = nullptr);
QString name() const override; QString name() const override;
QString url() const override; QString url() const override;
QString completionEndpoint() const override; QString completionEndpoint() const override;
@ -45,6 +50,9 @@ public:
void sendRequest( void sendRequest(
const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) override; const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) override;
bool supportsTools() const override;
void cancelRequest(const LLMCore::RequestID &requestId) override;
public slots: public slots:
void onDataReceived( void onDataReceived(
const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data) override; const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data) override;
@ -52,6 +60,20 @@ public slots:
const QodeAssist::LLMCore::RequestID &requestId, const QodeAssist::LLMCore::RequestID &requestId,
bool success, bool success,
const QString &error) override; const QString &error) override;
private slots:
void onToolExecutionComplete(
const QString &requestId, const QHash<QString, QString> &toolResults);
private:
void processStreamChunk(const QString &requestId, const QJsonObject &chunk);
void handleMessageComplete(const QString &requestId);
void cleanupRequest(const LLMCore::RequestID &requestId);
QHash<LLMCore::RequestID, OpenAIMessage *> m_messages;
QHash<LLMCore::RequestID, QUrl> m_requestUrls;
QHash<LLMCore::RequestID, QJsonObject> m_originalRequests;
Tools::ToolsManager *m_toolsManager;
}; };
} // namespace QodeAssist::Providers } // namespace QodeAssist::Providers