Fix systemPrompt and context working

This commit is contained in:
Petr Mironychev 2024-11-16 10:20:57 +01:00 committed by GitHub
parent 7af8fc2ddc
commit 5e813ba402
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 31 additions and 44 deletions

View File

@ -152,11 +152,11 @@ QList<MessagePart> ChatModel::processMessageContent(const QString &content) cons
return parts; return parts;
} }
QJsonArray ChatModel::prepareMessagesForRequest(LLMCore::ContextData context) const QJsonArray ChatModel::prepareMessagesForRequest(const QString &systemPrompt) const
{ {
QJsonArray messages; QJsonArray messages;
messages.append(QJsonObject{{"role", "system"}, {"content", context.systemPrompt}}); messages.append(QJsonObject{{"role", "system"}, {"content", systemPrompt}});
for (const auto &message : m_messages) { for (const auto &message : m_messages) {
QString role; QString role;

View File

@ -60,7 +60,7 @@ public:
Q_INVOKABLE QList<MessagePart> processMessageContent(const QString &content) const; Q_INVOKABLE QList<MessagePart> processMessageContent(const QString &content) const;
QVector<Message> getChatHistory() const; QVector<Message> getChatHistory() const;
QJsonArray prepareMessagesForRequest(LLMCore::ContextData context) const; QJsonArray prepareMessagesForRequest(const QString &systemPrompt) const;
int totalTokens() const; int totalTokens() const;
int tokensThreshold() const; int tokensThreshold() const;

View File

@ -81,24 +81,21 @@ void ClientInterface::sendMessage(const QString &message, bool includeCurrentFil
context.prefix = message; context.prefix = message;
context.suffix = ""; context.suffix = "";
QString systemPrompt = chatAssistantSettings.systemPrompt(); QString systemPrompt;
if (chatAssistantSettings.useSystemPrompt())
systemPrompt = chatAssistantSettings.systemPrompt();
if (includeCurrentFile) { if (includeCurrentFile) {
QString fileContext = getCurrentFileContext(); QString fileContext = getCurrentFileContext();
if (!fileContext.isEmpty()) { if (!fileContext.isEmpty()) {
context.systemPrompt = QString("%1\n\n%2").arg(systemPrompt, fileContext); systemPrompt = systemPrompt.append(fileContext);
LOG_MESSAGE("Using system prompt with file context");
} else {
context.systemPrompt = systemPrompt;
LOG_MESSAGE("Failed to get file context, using default system prompt");
} }
} else {
context.systemPrompt = systemPrompt;
} }
QJsonObject providerRequest; QJsonObject providerRequest;
providerRequest["model"] = Settings::generalSettings().caModel(); providerRequest["model"] = Settings::generalSettings().caModel();
providerRequest["stream"] = true; providerRequest["stream"] = true;
providerRequest["messages"] = m_chatModel->prepareMessagesForRequest(context); providerRequest["messages"] = m_chatModel->prepareMessagesForRequest(systemPrompt);
if (promptTemplate) if (promptTemplate)
promptTemplate->prepareRequest(providerRequest, context); promptTemplate->prepareRequest(providerRequest, context);

View File

@ -207,9 +207,15 @@ LLMCore::ContextData DocumentContextReader::prepareContext(int lineNumber, int c
{ {
QString contextBefore = getContextBefore(lineNumber, cursorPosition); QString contextBefore = getContextBefore(lineNumber, cursorPosition);
QString contextAfter = getContextAfter(lineNumber, cursorPosition); QString contextAfter = getContextAfter(lineNumber, cursorPosition);
QString instructions = getInstructions();
return {contextBefore, contextAfter, instructions}; QString fileContext;
if (Settings::codeCompletionSettings().useFilePathInContext())
fileContext += getLanguageAndFileInfo();
if (Settings::codeCompletionSettings().useProjectChangesCache())
fileContext += ChangesManager::instance().getRecentChangesContext(m_textDocument);
return {contextBefore, contextAfter, fileContext};
} }
QString DocumentContextReader::getContextBefore(int lineNumber, int cursorPosition) const QString DocumentContextReader::getContextBefore(int lineNumber, int cursorPosition) const
@ -239,17 +245,4 @@ QString DocumentContextReader::getContextAfter(int lineNumber, int cursorPositio
} }
} }
QString DocumentContextReader::getInstructions() const
{
QString instructions;
if (Settings::codeCompletionSettings().useFilePathInContext())
instructions += getLanguageAndFileInfo();
if (Settings::codeCompletionSettings().useProjectChangesCache())
instructions += ChangesManager::instance().getRecentChangesContext(m_textDocument);
return instructions;
}
} // namespace QodeAssist } // namespace QodeAssist

View File

@ -54,7 +54,6 @@ public:
private: private:
QString getContextBefore(int lineNumber, int cursorPosition) const; QString getContextBefore(int lineNumber, int cursorPosition) const;
QString getContextAfter(int lineNumber, int cursorPosition) const; QString getContextAfter(int lineNumber, int cursorPosition) const;
QString getInstructions() const;
private: private:
TextEditor::TextDocument *m_textDocument; TextEditor::TextDocument *m_textDocument;

View File

@ -169,8 +169,13 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request)
QJsonArray::fromStringList(config.promptTemplate->stopWords())}}; QJsonArray::fromStringList(config.promptTemplate->stopWords())}};
config.multiLineCompletion = completeSettings.multiLineCompletion(); config.multiLineCompletion = completeSettings.multiLineCompletion();
QString systemPrompt;
if (completeSettings.useSystemPrompt()) if (completeSettings.useSystemPrompt())
config.providerRequest["system"] = completeSettings.systemPrompt(); systemPrompt.append(completeSettings.systemPrompt());
if (!updatedContext.fileContext.isEmpty())
systemPrompt.append(updatedContext.fileContext);
config.providerRequest["system"] = systemPrompt;
config.promptTemplate->prepareRequest(config.providerRequest, updatedContext); config.promptTemplate->prepareRequest(config.providerRequest, updatedContext);
config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::Fim); config.provider->prepareRequest(config.providerRequest, LLMCore::RequestType::Fim);

View File

@ -27,7 +27,7 @@ struct ContextData
{ {
QString prefix; QString prefix;
QString suffix; QString suffix;
QString systemPrompt; QString fileContext;
}; };
} // namespace QodeAssist::LLMCore } // namespace QodeAssist::LLMCore

View File

@ -148,7 +148,7 @@ CodeCompletionSettings::CodeCompletionSettings()
"and contextually appropriate code suggestions."); "and contextually appropriate code suggestions.");
useFilePathInContext.setSettingsKey(Constants::CC_USE_FILE_PATH_IN_CONTEXT); useFilePathInContext.setSettingsKey(Constants::CC_USE_FILE_PATH_IN_CONTEXT);
useFilePathInContext.setDefaultValue(false); useFilePathInContext.setDefaultValue(true);
useFilePathInContext.setLabelText(Tr::tr("Use File Path in Context")); useFilePathInContext.setLabelText(Tr::tr("Use File Path in Context"));
useProjectChangesCache.setSettingsKey(Constants::CC_USE_PROJECT_CHANGES_CACHE); useProjectChangesCache.setSettingsKey(Constants::CC_USE_PROJECT_CHANGES_CACHE);

View File

@ -28,7 +28,7 @@ class CodeLlamaFim : public LLMCore::PromptTemplate
public: public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; } LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
QString name() const override { return "CodeLlama FIM"; } QString name() const override { return "CodeLlama FIM"; }
QString promptTemplate() const override { return "%1<PRE> %2 <SUF>%3 <MID>"; } QString promptTemplate() const override { return "<PRE> %1 <SUF>%2 <MID>"; }
QStringList stopWords() const override QStringList stopWords() const override
{ {
return QStringList() << "<EOT>" << "<PRE>" << "<SUF" << "<MID>"; return QStringList() << "<EOT>" << "<PRE>" << "<SUF" << "<MID>";
@ -36,9 +36,7 @@ public:
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{ {
QString formattedPrompt = promptTemplate().arg(context.systemPrompt, QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
context.prefix,
context.suffix);
request["prompt"] = formattedPrompt; request["prompt"] = formattedPrompt;
} }
}; };

View File

@ -62,7 +62,6 @@ private:
{ {
if (value.isString()) { if (value.isString()) {
QString str = value.toString(); QString str = value.toString();
str.replace("{{QODE_INSTRUCTIONS}}", context.systemPrompt);
str.replace("{{QODE_PREFIX}}", context.prefix); str.replace("{{QODE_PREFIX}}", context.prefix);
str.replace("{{QODE_SUFFIX}}", context.suffix); str.replace("{{QODE_SUFFIX}}", context.suffix);
return str; return str;

View File

@ -30,14 +30,12 @@ public:
QString name() const override { return "DeepSeekCoder FIM"; } QString name() const override { return "DeepSeekCoder FIM"; }
QString promptTemplate() const override QString promptTemplate() const override
{ {
return "%1<fim▁begin>%2<fim▁hole>%3<fim▁end>"; return "<fim▁begin>%1<fim▁hole>%2<fim▁end>";
} }
QStringList stopWords() const override { return QStringList(); } QStringList stopWords() const override { return QStringList(); }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{ {
QString formattedPrompt = promptTemplate().arg(context.systemPrompt, QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
context.prefix,
context.suffix);
request["prompt"] = formattedPrompt; request["prompt"] = formattedPrompt;
} }
}; };

View File

@ -28,7 +28,7 @@ class StarCoder2Fim : public LLMCore::PromptTemplate
public: public:
LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; } LLMCore::TemplateType type() const override { return LLMCore::TemplateType::Fim; }
QString name() const override { return "StarCoder2 FIM"; } QString name() const override { return "StarCoder2 FIM"; }
QString promptTemplate() const override { return "%1<fim_prefix>%2<fim_suffix>%3<fim_middle>"; } QString promptTemplate() const override { return "<fim_prefix>%1<fim_suffix>%2<fim_middle>"; }
QStringList stopWords() const override QStringList stopWords() const override
{ {
return QStringList() << "<|endoftext|>" << "<file_sep>" << "<fim_prefix>" << "<fim_suffix>" return QStringList() << "<|endoftext|>" << "<file_sep>" << "<fim_prefix>" << "<fim_suffix>"
@ -36,9 +36,7 @@ public:
} }
void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override void prepareRequest(QJsonObject &request, const LLMCore::ContextData &context) const override
{ {
QString formattedPrompt = promptTemplate().arg(context.systemPrompt, QString formattedPrompt = promptTemplate().arg(context.prefix, context.suffix);
context.prefix,
context.suffix);
request["prompt"] = formattedPrompt; request["prompt"] = formattedPrompt;
} }
}; };