feat: Add request validator

This commit is contained in:
Petr Mironychev 2024-12-15 02:08:35 +01:00
parent 10e8b16caf
commit 7376a11a05
12 changed files with 194 additions and 1 deletions

View File

@ -131,6 +131,13 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil
QJsonObject request;
request["id"] = QUuid::createUuid().toString();
auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {
LOG_MESSAGE("Validate errors for chat request:");
LOG_MESSAGES(errors);
return;
}
m_requestHandler->sendLLMRequest(config, request);
}

View File

@ -195,6 +195,12 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
config.promptTemplate->prepareRequest(config.providerRequest, updatedContext);
config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::Fim);
auto errors = config.provider->validateRequest(config.providerRequest, promptTemplate->type());
if (!errors.isEmpty()) {
LOG_MESSAGE("Validate errors for fim request:");
LOG_MESSAGES(errors);
return;
}
m_requestHandler.sendLLMRequest(config, request);
}

View File

@ -9,6 +9,7 @@ add_library(LLMCore STATIC
RequestHandler.hpp RequestHandler.cpp
OllamaMessage.hpp OllamaMessage.cpp
OpenAIMessage.hpp OpenAIMessage.cpp
ValidationUtils.hpp ValidationUtils.cpp
)
target_link_libraries(LLMCore

View File

@ -20,9 +20,11 @@
#pragma once
#include <QString>
#include "RequestType.hpp"
#include <utils/environment.h>
#include "PromptTemplate.hpp"
#include "RequestType.hpp"
class QNetworkReply;
class QJsonObject;
@ -42,6 +44,7 @@ public:
virtual void prepareRequest(QJsonObject &request, RequestType type) = 0;
virtual bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) = 0;
virtual QList<QString> getInstalledModels(const QString &url) = 0;
virtual QList<QString> validateRequest(const QJsonObject &request, TemplateType type) = 0;
};
} // namespace QodeAssist::LLMCore

View File

@ -0,0 +1,57 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#include "ValidationUtils.hpp"
#include <QJsonArray>
namespace QodeAssist::LLMCore {
QStringList ValidationUtils::validateRequestFields(
const QJsonObject &request, const QJsonObject &templateObj)
{
QStringList errors;
validateFields(request, templateObj, errors);
validateNestedObjects(request, templateObj, errors);
return errors;
}
void ValidationUtils::validateFields(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors)
{
for (auto it = request.begin(); it != request.end(); ++it) {
if (!templateObj.contains(it.key())) {
errors << QString("unknown field '%1'").arg(it.key());
}
}
}
void ValidationUtils::validateNestedObjects(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors)
{
for (auto it = request.begin(); it != request.end(); ++it) {
if (templateObj.contains(it.key()) && it.value().isObject()
&& templateObj[it.key()].isObject()) {
validateFields(it.value().toObject(), templateObj[it.key()].toObject(), errors);
validateNestedObjects(it.value().toObject(), templateObj[it.key()].toObject(), errors);
}
}
}
} // namespace QodeAssist::LLMCore

View File

@ -0,0 +1,41 @@
/*
* Copyright (C) 2024 Petr Mironychev
*
* This file is part of QodeAssist.
*
* QodeAssist is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* QodeAssist is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with QodeAssist. If not, see <https://www.gnu.org/licenses/>.
*/
#pragma once
#include <QJsonObject>
#include <QStringList>
namespace QodeAssist::LLMCore {
class ValidationUtils
{
public:
static QStringList validateRequestFields(
const QJsonObject &request, const QJsonObject &templateObj);
private:
static void validateFields(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors);
static void validateNestedObjects(
const QJsonObject &request, const QJsonObject &templateObj, QStringList &errors);
};
} // namespace QodeAssist::LLMCore

View File

@ -26,6 +26,7 @@
#include <QNetworkReply>
#include "llmcore/OpenAIMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
@ -169,4 +170,22 @@ QList<QString> LMStudioProvider::getInstalledModels(const QString &url)
return models;
}
QList<QString> LMStudioProvider::validateRequest(
const QJsonObject &request, LLMCore::TemplateType type)
{
const auto templateReq = QJsonObject{
{"model", {}},
{"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
{"temperature", {}},
{"max_tokens", {}},
{"top_p", {}},
{"top_k", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}},
{"stop", QJsonArray{}},
{"stream", {}}};
return LLMCore::ValidationUtils::validateRequestFields(request, templateReq);
}
} // namespace QodeAssist::Providers

View File

@ -36,6 +36,7 @@ public:
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
};
} // namespace QodeAssist::Providers

View File

@ -26,6 +26,7 @@
#include <QtCore/qeventloop.h>
#include "llmcore/OllamaMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
#include "settings/ChatAssistantSettings.hpp"
#include "settings/CodeCompletionSettings.hpp"
@ -137,4 +138,40 @@ QList<QString> OllamaProvider::getInstalledModels(const QString &url)
return models;
}
QList<QString> OllamaProvider::validateRequest(const QJsonObject &request, LLMCore::TemplateType type)
{
const auto fimReq = QJsonObject{
{"keep_alive", {}},
{"model", {}},
{"stream", {}},
{"prompt", {}},
{"suffix", {}},
{"system", {}},
{"options",
QJsonObject{
{"temperature", {}},
{"top_p", {}},
{"top_k", {}},
{"num_predict", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}}}}};
const auto messageReq = QJsonObject{
{"keep_alive", {}},
{"model", {}},
{"stream", {}},
{"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
{"options",
QJsonObject{
{"temperature", {}},
{"top_p", {}},
{"top_k", {}},
{"num_predict", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}}}}};
return LLMCore::ValidationUtils::validateRequestFields(
request, type == LLMCore::TemplateType::Fim ? fimReq : messageReq);
};
} // namespace QodeAssist::Providers

View File

@ -36,6 +36,7 @@ public:
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
};
} // namespace QodeAssist::Providers

View File

@ -27,6 +27,7 @@
#include <QNetworkReply>
#include "llmcore/OpenAIMessage.hpp"
#include "llmcore/ValidationUtils.hpp"
#include "logger/Logger.hpp"
namespace QodeAssist::Providers {
@ -142,4 +143,22 @@ QList<QString> OpenAICompatProvider::getInstalledModels(const QString &url)
return QStringList();
}
QList<QString> OpenAICompatProvider::validateRequest(
const QJsonObject &request, LLMCore::TemplateType type)
{
const auto templateReq = QJsonObject{
{"model", {}},
{"messages", QJsonArray{{QJsonObject{{"role", {}}, {"content", {}}}}}},
{"temperature", {}},
{"max_tokens", {}},
{"top_p", {}},
{"top_k", {}},
{"frequency_penalty", {}},
{"presence_penalty", {}},
{"stop", QJsonArray{}},
{"stream", {}}};
return LLMCore::ValidationUtils::validateRequestFields(request, templateReq);
}
} // namespace QodeAssist::Providers

View File

@ -36,6 +36,7 @@ public:
void prepareRequest(QJsonObject &request, LLMCore::RequestType type) override;
bool handleResponse(QNetworkReply *reply, QString &accumulatedResponse) override;
QList<QString> getInstalledModels(const QString &url) override;
QList<QString> validateRequest(const QJsonObject &request, LLMCore::TemplateType type) override;
};
} // namespace QodeAssist::Providers