feat: Improve OpenAI message handling

This commit is contained in:
Petr Mironychev 2024-11-26 11:43:51 +01:00
parent b475f15e3d
commit 56b5ea8e68

View File

@ -26,13 +26,15 @@
#include <QJsonObject> #include <QJsonObject>
#include <QNetworkReply> #include <QNetworkReply>
#include "logger/Logger.hpp"
namespace QodeAssist::Providers { namespace QodeAssist::Providers {
OpenAICompatProvider::OpenAICompatProvider() {} OpenAICompatProvider::OpenAICompatProvider() {}
QString OpenAICompatProvider::name() const QString OpenAICompatProvider::name() const
{ {
return "OpenAI Compatible (experimental)"; return "OpenAI Compatible";
} }
QString OpenAICompatProvider::url() const QString OpenAICompatProvider::url() const
@ -99,24 +101,41 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request, LLMCore::Request
bool OpenAICompatProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse) bool OpenAICompatProvider::handleResponse(QNetworkReply *reply, QString &accumulatedResponse)
{ {
bool isComplete = false; bool isComplete = false;
QString tempResponse = accumulatedResponse;
while (reply->canReadLine()) { while (reply->canReadLine()) {
QByteArray line = reply->readLine().trimmed(); QByteArray line = reply->readLine().trimmed();
if (line.isEmpty()) { if (line.isEmpty()) {
continue; continue;
} }
if (line == "data: [DONE]") {
if (!line.startsWith("data:")) {
continue;
}
line = line.mid(6);
if (line == "[DONE]") {
isComplete = true; isComplete = true;
break; break;
} }
if (line.startsWith("data: ")) {
line = line.mid(6); // Remove "data: " prefix
}
QJsonDocument jsonResponse = QJsonDocument::fromJson(line); QJsonDocument jsonResponse = QJsonDocument::fromJson(line);
if (jsonResponse.isNull()) { if (jsonResponse.isNull()) {
qWarning() << "Invalid JSON response from LM Studio:" << line; LOG_MESSAGE(
"Invalid JSON response from OpenAI compatible provider: " + QString::fromUtf8(line));
continue; continue;
} }
QJsonObject responseObj = jsonResponse.object(); QJsonObject responseObj = jsonResponse.object();
if (responseObj.contains("error")) {
LOG_MESSAGE(
"OpenAI compatible provider error: "
+ QString::fromUtf8(QJsonDocument(responseObj).toJson(QJsonDocument::Indented)));
return false;
}
if (responseObj.contains("choices")) { if (responseObj.contains("choices")) {
QJsonArray choices = responseObj["choices"].toArray(); QJsonArray choices = responseObj["choices"].toArray();
if (!choices.isEmpty()) { if (!choices.isEmpty()) {
@ -124,16 +143,30 @@ bool OpenAICompatProvider::handleResponse(QNetworkReply *reply, QString &accumul
QJsonObject delta = choice["delta"].toObject(); QJsonObject delta = choice["delta"].toObject();
if (delta.contains("content")) { if (delta.contains("content")) {
QString completion = delta["content"].toString(); QString completion = delta["content"].toString();
if (!completion.isEmpty()) {
accumulatedResponse += completion; tempResponse += completion;
} }
if (choice["finish_reason"].toString() == "stop") { }
QString finishReason = choice["finish_reason"].toString();
if (!finishReason.isNull() && finishReason == "stop") {
isComplete = true; isComplete = true;
break;
} }
} }
} }
if (responseObj.contains("usage")) {
QJsonObject usage = responseObj["usage"].toObject();
LOG_MESSAGE(QString("Token usage - Prompt: %1, Completion: %2, Total: %3")
.arg(usage["prompt_tokens"].toInt())
.arg(usage["completion_tokens"].toInt())
.arg(usage["total_tokens"].toInt()));
} }
}
if (!tempResponse.isEmpty()) {
accumulatedResponse = tempResponse;
}
return isComplete; return isComplete;
} }