From 80eda8c1678f511a2932a7f891f9e62bc473039f Mon Sep 17 00:00:00 2001 From: Petr Mironychev <9195189+Palm1r@users.noreply.github.com> Date: Wed, 16 Oct 2024 22:51:34 +0200 Subject: [PATCH] Add stream text to chat --- chatview/ChatModel.cpp | 33 ++++++++++++++++++++++----------- chatview/ChatModel.hpp | 3 ++- chatview/ClientInterface.cpp | 19 +++++++++---------- chatview/ClientInterface.hpp | 3 +-- llmcore/RequestHandler.cpp | 22 +++++++++------------- 5 files changed, 43 insertions(+), 37 deletions(-) diff --git a/chatview/ChatModel.cpp b/chatview/ChatModel.cpp index b669255..ce63a55 100644 --- a/chatview/ChatModel.cpp +++ b/chatview/ChatModel.cpp @@ -68,6 +68,28 @@ QHash ChatModel::roleNames() const return roles; } +void ChatModel::addMessage(const QString &content, ChatRole role, const QString &id) +{ + int tokenCount = estimateTokenCount(content); + + if (!m_messages.isEmpty() && !id.isEmpty() && m_messages.last().id == id) { + Message &lastMessage = m_messages.last(); + int oldTokenCount = lastMessage.tokenCount; + lastMessage.content = content; + lastMessage.tokenCount = tokenCount; + m_totalTokens += (tokenCount - oldTokenCount); + emit dataChanged(index(m_messages.size() - 1), index(m_messages.size() - 1)); + } else { + beginInsertRows(QModelIndex(), m_messages.size(), m_messages.size()); + m_messages.append({role, content, tokenCount, id}); + m_totalTokens += tokenCount; + endInsertRows(); + } + + trim(); + emit totalTokensChanged(); +} + QVector ChatModel::getChatHistory() const { return m_messages; @@ -92,17 +114,6 @@ int ChatModel::estimateTokenCount(const QString &text) const return text.length() / 4; } -void ChatModel::addMessage(const QString &content, ChatRole role) -{ - int tokenCount = estimateTokenCount(content); - beginInsertRows(QModelIndex(), m_messages.size(), m_messages.size()); - m_messages.append({role, content, tokenCount}); - m_totalTokens += tokenCount; - endInsertRows(); - trim(); - emit totalTokensChanged(); -} - void ChatModel::clear() { beginResetModel(); diff --git a/chatview/ChatModel.hpp b/chatview/ChatModel.hpp index b13630b..df6c8b1 100644 --- a/chatview/ChatModel.hpp +++ b/chatview/ChatModel.hpp @@ -46,6 +46,7 @@ public: ChatRole role; QString content; int tokenCount; + QString id; }; explicit ChatModel(QObject *parent = nullptr); @@ -54,7 +55,7 @@ public: QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override; QHash roleNames() const override; - Q_INVOKABLE void addMessage(const QString &content, ChatRole role); + Q_INVOKABLE void addMessage(const QString &content, ChatRole role, const QString &id); Q_INVOKABLE void clear(); Q_INVOKABLE QList processMessageContent(const QString &content) const; diff --git a/chatview/ClientInterface.cpp b/chatview/ClientInterface.cpp index 4b0b114..e20fa55 100644 --- a/chatview/ClientInterface.cpp +++ b/chatview/ClientInterface.cpp @@ -38,8 +38,8 @@ ClientInterface::ClientInterface(ChatModel *chatModel, QObject *parent) connect(m_requestHandler, &LLMCore::RequestHandler::completionReceived, this, - [this](const QString &completion, const QJsonObject &, bool isComplete) { - handleLLMResponse(completion, isComplete); + [this](const QString &completion, const QJsonObject &request, bool isComplete) { + handleLLMResponse(completion, request, isComplete); }); connect(m_requestHandler, @@ -90,7 +90,7 @@ void ClientInterface::sendMessage(const QString &message) request["id"] = QUuid::createUuid().toString(); m_accumulatedResponse.clear(); - m_chatModel->addMessage(message, ChatModel::ChatRole::User); + m_chatModel->addMessage(message, ChatModel::ChatRole::User, ""); m_requestHandler->sendLLMRequest(config, request); } @@ -101,16 +101,15 @@ void ClientInterface::clearMessages() LOG_MESSAGE("Chat history cleared"); } -void ClientInterface::handleLLMResponse(const QString &response, bool isComplete) +void ClientInterface::handleLLMResponse(const QString &response, + const QJsonObject &request, + bool isComplete) { - m_accumulatedResponse += response; + QString messageId = request["id"].toString(); + m_chatModel->addMessage(response.trimmed(), ChatModel::ChatRole::Assistant, messageId); if (isComplete) { - LOG_MESSAGE("Message completed. Final response: " + m_accumulatedResponse); - emit messageReceived(m_accumulatedResponse.trimmed()); - - m_chatModel->addMessage(m_accumulatedResponse.trimmed(), ChatModel::ChatRole::Assistant); - m_accumulatedResponse.clear(); + LOG_MESSAGE("Message completed. Final response for message " + messageId + ": " + response); } } diff --git a/chatview/ClientInterface.hpp b/chatview/ClientInterface.hpp index 24863ce..bd24f4b 100644 --- a/chatview/ClientInterface.hpp +++ b/chatview/ClientInterface.hpp @@ -40,11 +40,10 @@ public: void clearMessages(); signals: - void messageReceived(const QString &message); void errorOccurred(const QString &error); private: - void handleLLMResponse(const QString &response, bool isComplete); + void handleLLMResponse(const QString &response, const QJsonObject &request, bool isComplete); LLMCore::RequestHandler *m_requestHandler; QString m_accumulatedResponse; diff --git a/llmcore/RequestHandler.cpp b/llmcore/RequestHandler.cpp index 99e19ca..1a7e5c6 100644 --- a/llmcore/RequestHandler.cpp +++ b/llmcore/RequestHandler.cpp @@ -80,22 +80,18 @@ void RequestHandler::handleLLMResponse(QNetworkReply *reply, && processSingleLineCompletion(reply, request, accumulatedResponse, config)) { return; } + + if (isComplete) { + auto cleanedCompletion = removeStopWords(accumulatedResponse, + config.promptTemplate->stopWords()); + emit completionReceived(cleanedCompletion, request, true); + } + } else if (config.requestType == RequestType::Chat) { + emit completionReceived(accumulatedResponse, request, isComplete); } - if (isComplete || reply->isFinished()) { - if (isComplete) { - if (config.requestType == RequestType::Fim) { - auto cleanedCompletion = removeStopWords(accumulatedResponse, - config.promptTemplate->stopWords()); - emit completionReceived(cleanedCompletion, request, true); - } else { - emit completionReceived(accumulatedResponse, request, true); - } - } else { - emit completionReceived(accumulatedResponse, request, false); - } + if (isComplete) m_accumulatedResponses.remove(reply); - } } bool RequestHandler::cancelRequest(const QString &id)