diff --git a/providers/LlamaCppProvider.cpp b/providers/LlamaCppProvider.cpp
index eb8b578..6a5a005 100644
--- a/providers/LlamaCppProvider.cpp
+++ b/providers/LlamaCppProvider.cpp
@@ -19,20 +19,30 @@
 #include "LlamaCppProvider.hpp"
+#include "llmcore/ValidationUtils.hpp"
+#include "logger/Logger.hpp"
+#include "settings/ChatAssistantSettings.hpp"
+#include "settings/CodeCompletionSettings.hpp"
+
 #include <QEventLoop>
 #include <QJsonArray>
 #include <QJsonDocument>
 #include <QJsonObject>
 #include <QNetworkReply>
 
-#include "llmcore/OpenAIMessage.hpp"
-#include "llmcore/ValidationUtils.hpp"
-#include "logger/Logger.hpp"
-#include "settings/ChatAssistantSettings.hpp"
-#include "settings/CodeCompletionSettings.hpp"
-
 namespace QodeAssist::Providers {
 
+LlamaCppProvider::LlamaCppProvider(QObject *parent)
+    : LLMCore::Provider(parent)
+    , m_toolsManager(new Tools::ToolsManager(this))
+{
+    connect(
+        m_toolsManager,
+        &Tools::ToolsManager::toolExecutionComplete,
+        this,
+        &LlamaCppProvider::onToolExecutionComplete);
+}
+
 QString LlamaCppProvider::name() const
 {
     return "llama.cpp";
 }
@@ -89,6 +99,15 @@
     } else {
         applyModelParams(Settings::chatAssistantSettings());
     }
+
+    if (supportsTools() && type == LLMCore::RequestType::Chat
+        && Settings::chatAssistantSettings().useTools()) {
+        auto toolsDefinitions = m_toolsManager->getToolsDefinitions(Tools::ToolSchemaFormat::OpenAI);
+        if (!toolsDefinitions.isEmpty()) {
+            request["tools"] = toolsDefinitions;
+            LOG_MESSAGE(QString("Added %1 tools to llama.cpp request").arg(toolsDefinitions.size()));
+        }
+    }
 }
 
 QList<QString> LlamaCppProvider::getInstalledModels(const QString &url)
@@ -127,7 +146,8 @@
         {"frequency_penalty", {}},
         {"presence_penalty", {}},
         {"stop", QJsonArray{}},
-        {"stream", {}}};
+        {"stream", {}},
+        {"tools", {}}};
 
     return LLMCore::ValidationUtils::validateRequestFields(request, chatReq);
 }
@@ -151,8 +171,12 @@ LLMCore::ProviderID LlamaCppProvider::providerID() const
 
 void LlamaCppProvider::sendRequest(
     const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload)
 {
-    m_dataBuffers[requestId].clear();
+    if (!m_messages.contains(requestId)) {
+        m_dataBuffers[requestId].clear();
+    }
+    m_requestUrls[requestId] = url;
+    m_originalRequests[requestId] = payload;
 
     QNetworkRequest networkRequest(url);
     prepareNetworkRequest(networkRequest);
@@ -166,69 +190,46 @@ void LlamaCppProvider::sendRequest(
     emit httpClient()->sendRequest(request);
 }
 
+bool LlamaCppProvider::supportsTools() const
+{
+    return true;
+}
+
+void LlamaCppProvider::cancelRequest(const LLMCore::RequestID &requestId)
+{
+    LOG_MESSAGE(QString("LlamaCppProvider: Cancelling request %1").arg(requestId));
+    LLMCore::Provider::cancelRequest(requestId);
+    cleanupRequest(requestId);
+}
+
 void LlamaCppProvider::onDataReceived(
     const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data)
 {
     LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
     QStringList lines = buffers.rawStreamBuffer.processData(data);
 
-    if (data.isEmpty()) {
-        return;
-    }
-
-    bool isDone = data.contains("\"stop\":true") || data.contains("data: [DONE]");
-    QString tempResponse;
-
     for (const QString &line : lines) {
-        if (line.trimmed().isEmpty()) {
+        if (line.trimmed().isEmpty() || line == "data: [DONE]") {
             continue;
         }
 
-        if (line == "data: [DONE]") {
-            isDone = true;
+        QJsonObject chunk = parseEventLine(line);
+        if (chunk.isEmpty())
             continue;
-        }
 
-        QJsonObject obj = parseEventLine(line);
-        if (obj.isEmpty())
-            continue;
-
-        QString content;
-
-        if (obj.contains("content")) {
-            content = obj["content"].toString();
obj["content"].toString(); + if (chunk.contains("content")) { + QString content = chunk["content"].toString(); if (!content.isEmpty()) { - tempResponse += content; + buffers.responseContent += content; + emit partialResponseReceived(requestId, content); } - } else if (obj.contains("choices")) { - auto message = LLMCore::OpenAIMessage::fromJson(obj); - if (message.hasError()) { - LOG_MESSAGE("Error in llama.cpp response: " + message.error); - continue; - } - - content = message.getContent(); - if (!content.isEmpty()) { - tempResponse += content; - } - - if (message.isDone()) { - isDone = true; + if (chunk["stop"].toBool()) { + emit fullResponseReceived(requestId, buffers.responseContent); + m_dataBuffers.remove(requestId); } + } else if (chunk.contains("choices")) { + processStreamChunk(requestId, chunk); } - - if (obj["stop"].toBool()) { - isDone = true; - } - } - - if (!tempResponse.isEmpty()) { - buffers.responseContent += tempResponse; - emit partialResponseReceived(requestId, tempResponse); - } - - if (isDone) { - emit fullResponseReceived(requestId, buffers.responseContent); - m_dataBuffers.remove(requestId); } } @@ -238,17 +239,161 @@ void LlamaCppProvider::onRequestFinished( if (!success) { LOG_MESSAGE(QString("LlamaCppProvider request %1 failed: %2").arg(requestId, error)); emit requestFailed(requestId, error); - } else { - if (m_dataBuffers.contains(requestId)) { - const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; - if (!buffers.responseContent.isEmpty()) { - emit fullResponseReceived(requestId, buffers.responseContent); + cleanupRequest(requestId); + return; + } + + if (m_messages.contains(requestId)) { + OpenAIMessage *message = m_messages[requestId]; + if (message->state() == LLMCore::MessageState::RequiresToolExecution) { + LOG_MESSAGE(QString("Waiting for tools to complete for %1").arg(requestId)); + m_dataBuffers.remove(requestId); + return; + } + } + + if (m_dataBuffers.contains(requestId)) { + const LLMCore::DataBuffers &buffers = m_dataBuffers[requestId]; + if (!buffers.responseContent.isEmpty()) { + LOG_MESSAGE(QString("Emitting full response for %1").arg(requestId)); + emit fullResponseReceived(requestId, buffers.responseContent); + } + } + + cleanupRequest(requestId); +} + +void LlamaCppProvider::onToolExecutionComplete( + const QString &requestId, const QHash &toolResults) +{ + if (!m_messages.contains(requestId) || !m_requestUrls.contains(requestId)) { + LOG_MESSAGE(QString("ERROR: Missing data for continuation request %1").arg(requestId)); + cleanupRequest(requestId); + return; + } + + LOG_MESSAGE(QString("Tool execution complete for llama.cpp request %1").arg(requestId)); + + OpenAIMessage *message = m_messages[requestId]; + QJsonObject continuationRequest = m_originalRequests[requestId]; + QJsonArray messages = continuationRequest["messages"].toArray(); + + messages.append(message->toProviderFormat()); + + QJsonArray toolResultMessages = message->createToolResultMessages(toolResults); + for (const auto &toolMsg : toolResultMessages) { + messages.append(toolMsg); + } + + continuationRequest["messages"] = messages; + + LOG_MESSAGE(QString("Sending continuation request for %1 with %2 tool results") + .arg(requestId) + .arg(toolResults.size())); + + sendRequest(requestId, m_requestUrls[requestId], continuationRequest); +} + +void LlamaCppProvider::processStreamChunk(const QString &requestId, const QJsonObject &chunk) +{ + QJsonArray choices = chunk["choices"].toArray(); + if (choices.isEmpty()) { + return; + } + + QJsonObject choice = 
+    QJsonObject delta = choice["delta"].toObject();
+    QString finishReason = choice["finish_reason"].toString();
+
+    OpenAIMessage *message = m_messages.value(requestId);
+    if (!message) {
+        message = new OpenAIMessage(this);
+        m_messages[requestId] = message;
+        LOG_MESSAGE(QString("Created NEW OpenAIMessage for llama.cpp request %1").arg(requestId));
+    }
+
+    if (delta.contains("content") && !delta["content"].isNull()) {
+        QString content = delta["content"].toString();
+        message->handleContentDelta(content);
+
+        LLMCore::DataBuffers &buffers = m_dataBuffers[requestId];
+        buffers.responseContent += content;
+        emit partialResponseReceived(requestId, content);
+    }
+
+    if (delta.contains("tool_calls")) {
+        QJsonArray toolCalls = delta["tool_calls"].toArray();
+        for (const auto &toolCallValue : toolCalls) {
+            QJsonObject toolCall = toolCallValue.toObject();
+            int index = toolCall["index"].toInt();
+
+            if (toolCall.contains("id")) {
+                QString id = toolCall["id"].toString();
+                QJsonObject function = toolCall["function"].toObject();
+                QString name = function["name"].toString();
+                message->handleToolCallStart(index, id, name);
+            }
+
+            if (toolCall.contains("function")) {
+                QJsonObject function = toolCall["function"].toObject();
+                if (function.contains("arguments")) {
+                    QString args = function["arguments"].toString();
+                    message->handleToolCallDelta(index, args);
+                }
+            }
+        }
+    }
+
+    if (!finishReason.isEmpty() && finishReason != "null") {
+        for (int i = 0; i < 10; ++i) {
+            message->handleToolCallComplete(i);
+        }
+
+        message->handleFinishReason(finishReason);
+        handleMessageComplete(requestId);
+    }
+}
+
+void LlamaCppProvider::handleMessageComplete(const QString &requestId)
+{
+    if (!m_messages.contains(requestId))
+        return;
+
+    OpenAIMessage *message = m_messages[requestId];
+
+    if (message->state() == LLMCore::MessageState::RequiresToolExecution) {
+        LOG_MESSAGE(QString("llama.cpp message requires tool execution for %1").arg(requestId));
+
+        auto toolUseContent = message->getCurrentToolUseContent();
+
+        if (toolUseContent.isEmpty()) {
+            LOG_MESSAGE(QString("No tools to execute for %1").arg(requestId));
+            return;
+        }
+
+        for (auto toolContent : toolUseContent) {
+            m_toolsManager->executeToolCall(
+                requestId, toolContent->id(), toolContent->name(), toolContent->input());
+        }
+
+    } else {
+        LOG_MESSAGE(QString("llama.cpp message marked as complete for %1").arg(requestId));
+    }
+}
+
+void LlamaCppProvider::cleanupRequest(const LLMCore::RequestID &requestId)
+{
+    LOG_MESSAGE(QString("Cleaning up llama.cpp request %1").arg(requestId));
+
+    if (m_messages.contains(requestId)) {
+        OpenAIMessage *message = m_messages.take(requestId);
+        message->deleteLater();
+    }
+
     m_dataBuffers.remove(requestId);
     m_requestUrls.remove(requestId);
+    m_originalRequests.remove(requestId);
+    m_toolsManager->cleanupRequest(requestId);
 }
 
 } // namespace QodeAssist::Providers
diff --git a/providers/LlamaCppProvider.hpp b/providers/LlamaCppProvider.hpp
index a912f1c..ec6135e 100644
--- a/providers/LlamaCppProvider.hpp
+++ b/providers/LlamaCppProvider.hpp
@@ -19,13 +19,18 @@
 #pragma once
 
-#include "llmcore/Provider.hpp"
+#include "OpenAIMessage.hpp"
+#include "tools/ToolsManager.hpp"
+#include <QHash>
 
 namespace QodeAssist::Providers {
 
 class LlamaCppProvider : public LLMCore::Provider
 {
+    Q_OBJECT
 public:
+    explicit LlamaCppProvider(QObject *parent = nullptr);
+
     QString name() const override;
     QString url() const override;
     QString completionEndpoint() const override;
@@ -45,6 +50,9 @@
     void sendRequest(
         const LLMCore::RequestID &requestId, const QUrl &url, const QJsonObject &payload) override;
 
+    bool supportsTools() const override;
+    void cancelRequest(const LLMCore::RequestID &requestId) override;
+
 public slots:
     void onDataReceived(
         const QodeAssist::LLMCore::RequestID &requestId, const QByteArray &data) override;
@@ -52,6 +60,20 @@ public slots:
         const QodeAssist::LLMCore::RequestID &requestId,
         bool success,
         const QString &error) override;
+
+private slots:
+    void onToolExecutionComplete(
+        const QString &requestId, const QHash<QString, QString> &toolResults);
+
+private:
+    void processStreamChunk(const QString &requestId, const QJsonObject &chunk);
+    void handleMessageComplete(const QString &requestId);
+    void cleanupRequest(const LLMCore::RequestID &requestId);
+
+    QHash<QString, OpenAIMessage *> m_messages;
+    QHash<QString, QUrl> m_requestUrls;
+    QHash<QString, QJsonObject> m_originalRequests;
+    Tools::ToolsManager *m_toolsManager;
 };
 
 } // namespace QodeAssist::Providers
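Reviewer note (illustrative, not part of the diff): processStreamChunk() assumes llama.cpp's OpenAI-compatible streaming format, where each event carries a "choices" array with a "delta" object. Under that assumption, a tool-call round it would parse looks roughly like the trace below; the call id, tool name, and arguments are hypothetical:

    data: {"choices":[{"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"read_file","arguments":""}}]},"finish_reason":null}]}
    data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"main.cpp\"}"}}]},"finish_reason":null}]}
    data: {"choices":[{"delta":{},"finish_reason":"tool_calls"}]}

The first chunk carries the call id and function name (handleToolCallStart), later chunks stream argument fragments (handleToolCallDelta), and a non-empty finish_reason finalizes the message. handleMessageComplete() then dispatches the accumulated calls to ToolsManager; once onToolExecutionComplete() fires, the provider appends the assistant message plus the tool-result messages to the original payload's "messages" array and re-sends it, so the model can continue from the tool output.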