feat: Improve OpenAI message handling

This commit is contained in:
Petr Mironychev 2024-11-26 11:43:51 +01:00
parent b475f15e3d
commit 56b5ea8e68

View File

@ -26,13 +26,15 @@
#include <QJsonObject>
#include <QNetworkReply>
#include "logger/Logger.hpp"
namespace QodeAssist::Providers {
OpenAICompatProvider::OpenAICompatProvider() {}
QString OpenAICompatProvider::name() const
{
return "OpenAI Compatible (experimental)";
return "OpenAI Compatible";
}
QString OpenAICompatProvider::url() const
@ -99,24 +101,41 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::Request
bool OpenAICompatProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{
bool isComplete = false;
QString tempResponse = accumulatedResponse;
while (reply->canReadLine()) {
QByteArray line = reply->readLine().trimmed();
if (line.isEmpty()) {
continue;
}
if (line == "data: [DONE]") {
if (!line.startsWith("data:")) {
continue;
}
line = line.mid(6);
if (line == "[DONE]") {
isComplete = true;
break;
}
if (line.startsWith("data: ")) {
line = line.mid(6); // Remove "data: " prefix
}
QJsonDocument jsonResponse = QJsonDocument::fromJson(line);
if (jsonResponse.isNull()) {
qWarning() << "Invalid JSON response from LM Studio:" << line;
LOG_MESSAGE(
"Invalid JSON response from OpenAI compatible provider: " + QString::fromUtf8(line));
continue;
}
QJsonObject responseObj = jsonResponse.object();
if (responseObj.contains("error")) {
LOG_MESSAGE(
"OpenAI compatible provider error: "
+ QString::fromUtf8(QJsonDocument(responseObj).toJson(QJsonDocument::Indented)));
return false;
}
if (responseObj.contains("choices")) {
QJsonArray choices = responseObj["choices"].toArray();
if (!choices.isEmpty()) {
@ -124,16 +143,30 @@ bool OpenAICompatProvider::handleResponse(QNetworkReply *reply, QString &accumul
QJsonObject delta = choice["delta"].toObject();
if (delta.contains("content")) {
QString completion = delta["content"].toString();
accumulatedResponse += completion;
if (!completion.isEmpty()) {
tempResponse += completion;
}
}
if (choice["finish_reason"].toString() == "stop") {
QString finishReason = choice["finish_reason"].toString();
if (!finishReason.isNull() && finishReason == "stop") {
isComplete = true;
break;
}
}
}
if (responseObj.contains("usage")) {
QJsonObject usage = responseObj["usage"].toObject();
LOG_MESSAGE(QString("Token usage - Prompt: %1, Completion: %2, Total: %3")
.arg(usage["prompt_tokens"].toInt())
.arg(usage["completion_tokens"].toInt())
.arg(usage["total_tokens"].toInt()));
}
}
if (!tempResponse.isEmpty()) {
accumulatedResponse = tempResponse;
}
return isComplete;
}