From 7376a11a051362a5f3dac352e0c7c07f9ae249db Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Sun, 15 Dec 2024 02:08:35 +0100 Subject: [PATCH] :sparkles: feat: Add request validator --- ChatView/ClientInterface.cpp | 7 ++++ LLMClientInterface.cpp | 6 ++++ llmcore/CMakeLists.txt | 1 + llmcore/Provider.hpp | 5 ++- llmcore/ValidationUtils.cpp | 57 ++++++++++++++++++++++++++++++ llmcore/ValidationUtils.hpp | 41 +++++++++++++++++++++ providers/LMStudioProvider.cpp | 19 ++++++++++ providers/LMStudioProvider.hpp | 1 + providers/OllamaProvider.cpp | 37 +++++++++++++++++++ providers/OllamaProvider.hpp | 1 + providers/OpenAICompatProvider.cpp | 19 ++++++++++ providers/OpenAICompatProvider.hpp | 1 + 12 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 llmcore/ValidationUtils.cpp create mode 100644 llmcore/ValidationUtils.hpp diff --git a/ChatView/ClientInterface.cpp b/ChatView/ClientInterface.cpp index 139daae..06b7554 100644 --- a/ChatView/ClientInterface.cpp +++ b/ChatView/ClientInterface.cpp @@ -131,6 +131,13 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil QJsonObject request; request["id"] = QUuid::createUuid().toString(); + auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type()); + if (!errors.isEmpty()) { + LOG_MESSAGE("Validate errors for chat request:"); + LOG_MESSAGES(errors); + return; + } + m_requestHandler->sendLLMRequest(config, request); } diff --git a/LLMClientInterface.cpp b/LLMClientInterface.cpp index 329070b..9684c3e 100644 --- a/LLMClientInterface.cpp +++ b/LLMClientInterface.cpp @@ -195,6 +195,12 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request) config.promptTemplate->prepareRequest(config.providerRequest, updatedContext); config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::Fim); + auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type()); + if (!errors.isEmpty()) { + LOG_MESSAGE("Validate errors for fim request:"); + LOG_MESSAGES(errors); + return; + } m_requestHandler.sendLLMRequest(config, request); } diff --git a/llmcore/CMakeLists.txt b/llmcore/CMakeLists.txt index b4048b6..4a8e0d9 100644 --- a/llmcore/CMakeLists.txt +++ b/llmcore/CMakeLists.txt @@ -9,6 +9,7 @@ add_library(LLMCore STATIC RequestHandler.hpp RequestHandler.cpp OllamaMessage.hpp OllamaMessage.cpp OpenAIMessage.hpp OpenAIMessage.cpp + ValidationUtils.hpp ValidationUtils.cpp ) target_link_libraries(LLMCore diff --git a/llmcore/Provider.hpp b/llmcore/Provider.hpp index 481725e..ae13019 100644 --- a/llmcore/Provider.hpp +++ b/llmcore/Provider.hpp @@ -20,9 +20,11 @@ #pragma once #include -#include "RequestType.hpp" #include +#include "PromptTemplate.hpp" +#include "RequestType.hpp" + class QNetworkReply; class QJsonObject; @@ -42,6 +44,7 @@ public: virtual void prepareRequest(QJsonObject &request, RequestType type) = 0; virtual bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) = 0; virtual QList getInstalledModels(const QString &url) = 0; + virtual QList validateRequest(const QJsonObject &request, TemplateType type) = 0; }; } // namespace QodeAssist::LLMCore diff --git a/llmcore/ValidationUtils.cpp b/llmcore/ValidationUtils.cpp new file mode 100644 index 0000000..844c3b4 --- /dev/null +++ b/llmcore/ValidationUtils.cpp @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#include "ValidationUtils.hpp" + +#include + +namespace QodeAssist::LLMCore { + +QStringList ValidationUtils::validateRequestFields( + const QJsonObject &request, const QJsonObject &templateObj) +{ + QStringList errors; + validateFields(request, templateObj, errors); + validateNestedObjects(request, templateObj, errors); + return errors; +} + +void ValidationUtils::validateFields( + const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors) +{ + for (auto it = request.begin(); it != request.end(); ++it) { + if (!templateObj.contains(it.key())) { + errors << QString("unknown field '%1'").arg(it.key()); + } + } +} + +void ValidationUtils::validateNestedObjects( + const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors) +{ + for (auto it = request.begin(); it != request.end(); ++it) { + if (templateObj.contains(it.key()) && it.value().isObject() + && templateObj[it.key()].isObject()) { + validateFields(it.value().toObject(), templateObj[it.key()].toObject(), errors); + validateNestedObjects(it.value().toObject(), templateObj[it.key()].toObject(), errors); + } + } +} + +} // namespace QodeAssist::LLMCore diff --git a/llmcore/ValidationUtils.hpp b/llmcore/ValidationUtils.hpp new file mode 100644 index 0000000..543765a --- /dev/null +++ b/llmcore/ValidationUtils.hpp @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2024 Petr Mironychev + * + * This file is part of QodeAssist. + * + * QodeAssist is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * QodeAssist is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with QodeAssist. If not, see . + */ + +#pragma once + +#include +#include + +namespace QodeAssist::LLMCore { + +class ValidationUtils +{ +public: + static QStringList validateRequestFields( + const QJsonObject &request, const QJsonObject &templateObj); + +private: + static void validateFields( + const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors); + + static void validateNestedObjects( + const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors); +}; + +} // namespace QodeAssist::LLMCore diff --git a/providers/LMStudioProvider.cpp b/providers/LMStudioProvider.cpp index 68de130..f16ad8d 100644 --- a/providers/LMStudioProvider.cpp +++ b/providers/LMStudioProvider.cpp @@ -26,6 +26,7 @@ #include #include "llmcore/OpenAIMessage.hpp" +#include "llmcore/ValidationUtils.hpp" #include "logger/Logger.hpp" #include "settings/ChatAssistantSettings.hpp" #include "settings/CodeCompletionSettings.hpp" @@ -169,4 +170,22 @@ QList LMStudioProvider::getInstalledModels(const QString &url) return models; } +QList LMStudioProvider::validateRequest( + const QJsonObject &request, LLMCore::TemplateType type) +{ + const auto templateReq = QJsonObject{ + {"model", {}}, + {"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}}, + {"temperature", {}}, + {"max_tokens", {}}, + {"top_p", {}}, + {"top_k", {}}, + {"frequency_penalty", {}}, + {"presence_penalty", {}}, + {"stop", QJsonArray{}}, + {"stream", {}}}; + + return LLMCore::ValidationUtils::validateRequestFields(request, templateReq); +} + } // namespace QodeAssist::Providers diff --git a/providers/LMStudioProvider.hpp b/providers/LMStudioProvider.hpp index bf29579..ae9d355 100644 --- a/providers/LMStudioProvider.hpp +++ b/providers/LMStudioProvider.hpp @@ -36,6 +36,7 @@ public: void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override; bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; QList getInstalledModels(const QString &url) override; + QList validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override; }; } // namespace QodeAssist::Providers diff --git a/providers/OllamaProvider.cpp b/providers/OllamaProvider.cpp index 95fb692..58a0c97 100644 --- a/providers/OllamaProvider.cpp +++ b/providers/OllamaProvider.cpp @@ -26,6 +26,7 @@ #include #include "llmcore/OllamaMessage.hpp" +#include "llmcore/ValidationUtils.hpp" #include "logger/Logger.hpp" #include "settings/ChatAssistantSettings.hpp" #include "settings/CodeCompletionSettings.hpp" @@ -137,4 +138,40 @@ QList OllamaProvider::getInstalledModels(const QString &url) return models; } +QList OllamaProvider::validateRequest(const QJsonObject &request, LLMCore::TemplateType type) +{ + const auto fimReq = QJsonObject{ + {"keep_alive", {}}, + {"model", {}}, + {"stream", {}}, + {"prompt", {}}, + {"suffix", {}}, + {"system", {}}, + {"options", + QJsonObject{ + {"temperature", {}}, + {"top_p", {}}, + {"top_k", {}}, + {"num_predict", {}}, + {"frequency_penalty", {}}, + {"presence_penalty", {}}}}}; + + const auto messageReq = QJsonObject{ + {"keep_alive", {}}, + {"model", {}}, + {"stream", {}}, + {"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}}, + {"options", + QJsonObject{ + {"temperature", {}}, + {"top_p", {}}, + {"top_k", {}}, + {"num_predict", {}}, + {"frequency_penalty", {}}, + {"presence_penalty", {}}}}}; + + return LLMCore::ValidationUtils::validateRequestFields( + request, type == LLMCore::TemplateType::Fim ? fimReq : messageReq); +}; + } // namespace QodeAssist::Providers diff --git a/providers/OllamaProvider.hpp b/providers/OllamaProvider.hpp index 2fe20e0..41c2768 100644 --- a/providers/OllamaProvider.hpp +++ b/providers/OllamaProvider.hpp @@ -36,6 +36,7 @@ public: void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override; bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; QList getInstalledModels(const QString &url) override; + QList validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override; }; } // namespace QodeAssist::Providers diff --git a/providers/OpenAICompatProvider.cpp b/providers/OpenAICompatProvider.cpp index 2dee0be..45ceda9 100644 --- a/providers/OpenAICompatProvider.cpp +++ b/providers/OpenAICompatProvider.cpp @@ -27,6 +27,7 @@ #include #include "llmcore/OpenAIMessage.hpp" +#include "llmcore/ValidationUtils.hpp" #include "logger/Logger.hpp" namespace QodeAssist::Providers { @@ -142,4 +143,22 @@ QList OpenAICompatProvider::getInstalledModels(const QString &url) return QStringList(); } +QList OpenAICompatProvider::validateRequest( + const QJsonObject &request, LLMCore::TemplateType type) +{ + const auto templateReq = QJsonObject{ + {"model", {}}, + {"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}}, + {"temperature", {}}, + {"max_tokens", {}}, + {"top_p", {}}, + {"top_k", {}}, + {"frequency_penalty", {}}, + {"presence_penalty", {}}, + {"stop", QJsonArray{}}, + {"stream", {}}}; + + return LLMCore::ValidationUtils::validateRequestFields(request, templateReq); +} + } // namespace QodeAssist::Providers diff --git a/providers/OpenAICompatProvider.hpp b/providers/OpenAICompatProvider.hpp index 0dfae36..202d7c3 100644 --- a/providers/OpenAICompatProvider.hpp +++ b/providers/OpenAICompatProvider.hpp @@ -36,6 +36,7 @@ public: void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override; bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override; QList getInstalledModels(const QString &url) override; + QList validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override; }; } // namespace QodeAssist::Providers