feat(S1): iterative conversation for multi-round prompt refinement
- Conversation model: store dialog context (max 10 rounds), JSON messages - POST /api/prompt/continue: append round, build LLM context from history - GET/DELETE /api/conversation/🆔 retrieve or clear conversation - Vue: refine input card below result, round counter, reset button - Vue: continuePrompt API with conversation_id tracking Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -71,6 +71,11 @@ def create_app(config_class=None):
|
||||
from src.flask_prompt_master.routes.history_routes import history_bp
|
||||
app.register_blueprint(history_bp)
|
||||
|
||||
# 迭代对话上下文(多轮优化)
|
||||
from src.flask_prompt_master.models import conversation # noqa: F401
|
||||
from src.flask_prompt_master.routes.conversation_routes import conversation_bp
|
||||
app.register_blueprint(conversation_bp)
|
||||
|
||||
# 提示词结构化质量评价(多段文本 + 模型 JSON 评价 + 历史)
|
||||
from src.flask_prompt_master.models import prompt_quality_models # noqa: F401
|
||||
from src.flask_prompt_master.routes.prompt_quality_routes import prompt_quality_bp
|
||||
|
||||
62
src/flask_prompt_master/models/conversation.py
Normal file
62
src/flask_prompt_master/models/conversation.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
对话上下文模型:支持多轮迭代优化提示词
|
||||
"""
|
||||
from datetime import datetime
|
||||
from src.flask_prompt_master import db
|
||||
|
||||
|
||||
class Conversation(db.Model):
|
||||
__tablename__ = 'conversation'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
|
||||
user_id = db.Column(db.Integer, nullable=False, index=True, comment='用户ID')
|
||||
scene_type = db.Column(db.String(30), default='prompt', comment='场景类型: prompt/meal/poetry/report')
|
||||
context_messages = db.Column(db.JSON, nullable=False, default=list, comment='对话上下文 [{role, content}]')
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
MAX_ROUNDS = 10 # 最多保留最近 N 轮对话
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'user_id': self.user_id,
|
||||
'scene_type': self.scene_type,
|
||||
'context_messages': self.context_messages or [],
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def append_round(cls, conversation_id, user_id, scene_type, role, content):
|
||||
"""追加一轮对话,若不存在则新建"""
|
||||
conv = None
|
||||
if conversation_id:
|
||||
conv = cls.query.filter_by(id=conversation_id, user_id=user_id).first()
|
||||
if conv is None:
|
||||
conv = cls(user_id=user_id, scene_type=scene_type, context_messages=[])
|
||||
db.session.add(conv)
|
||||
|
||||
messages = list(conv.context_messages or [])
|
||||
messages.append({'role': role, 'content': content})
|
||||
|
||||
# 保留最近 MAX_ROUNDS 轮
|
||||
if len(messages) > cls.MAX_ROUNDS:
|
||||
messages = messages[-cls.MAX_ROUNDS:]
|
||||
|
||||
conv.context_messages = messages
|
||||
conv.updated_at = datetime.utcnow()
|
||||
db.session.flush()
|
||||
return conv
|
||||
|
||||
@classmethod
|
||||
def build_llm_context(cls, conversation_id, user_id, system_prompt, new_user_msg):
|
||||
"""构建 LLM 调用的 messages: system + history + new_user"""
|
||||
messages = [{'role': 'system', 'content': system_prompt}]
|
||||
if conversation_id:
|
||||
conv = cls.query.filter_by(id=conversation_id, user_id=user_id).first()
|
||||
if conv and conv.context_messages:
|
||||
messages.extend(conv.context_messages)
|
||||
messages.append({'role': 'user', 'content': new_user_msg})
|
||||
return messages
|
||||
154
src/flask_prompt_master/routes/conversation_routes.py
Normal file
154
src/flask_prompt_master/routes/conversation_routes.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
迭代对话路由:支持基于历史上下文的多轮提示词优化
|
||||
"""
|
||||
from flask import Blueprint, request, jsonify, current_app
|
||||
from openai import OpenAI
|
||||
import os
|
||||
import logging
|
||||
from src.flask_prompt_master import db
|
||||
from src.flask_prompt_master.user_context import get_current_user_id
|
||||
from src.flask_prompt_master.models.models import Prompt, PromptTemplate
|
||||
from src.flask_prompt_master.models.conversation import Conversation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
conversation_bp = Blueprint('conversation', __name__)
|
||||
|
||||
_client = OpenAI(
|
||||
api_key=os.environ.get('LLM_API_KEY') or '',
|
||||
base_url=os.environ.get('LLM_API_URL') or 'https://api.deepseek.com/v1',
|
||||
)
|
||||
|
||||
|
||||
@conversation_bp.route('/api/prompt/continue', methods=['POST'])
|
||||
def continue_prompt():
|
||||
"""基于已有结果继续优化提示词(多轮对话)"""
|
||||
data = request.get_json(silent=True) or {}
|
||||
conversation_id = data.get('conversation_id')
|
||||
previous_result = (data.get('previous_result') or '').strip()
|
||||
refine_instruction = (data.get('refine_instruction') or '').strip()
|
||||
template_id = data.get('template_id')
|
||||
|
||||
if not previous_result:
|
||||
return jsonify({'success': False, 'message': '缺少上一轮结果'}), 400
|
||||
if not refine_instruction:
|
||||
return jsonify({'success': False, 'message': '请输入优化指令'}), 400
|
||||
|
||||
if template_id is not None and template_id != '':
|
||||
try:
|
||||
template_id = int(template_id)
|
||||
except (TypeError, ValueError):
|
||||
template_id = None
|
||||
|
||||
user_id = get_current_user_id()
|
||||
|
||||
# 获取 system_prompt
|
||||
system_prompt = _get_system_prompt(template_id)
|
||||
|
||||
try:
|
||||
# 记录用户的新一轮指令到对话历史
|
||||
conv = Conversation.append_round(
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
scene_type='prompt',
|
||||
role='user',
|
||||
content=f"上一轮结果:\n{previous_result}\n\n优化指令:{refine_instruction}",
|
||||
)
|
||||
|
||||
# 构建完整上下文
|
||||
current_app.logger.info(
|
||||
f"迭代对话 conversation_id={conv.id} user_id={user_id} "
|
||||
f"rounds={len(conv.context_messages)} instruction={refine_instruction[:80]}"
|
||||
)
|
||||
|
||||
# 调用 LLM
|
||||
response = _client.chat.completions.create(
|
||||
model='deepseek-chat',
|
||||
messages=Conversation.build_llm_context(
|
||||
conv.id, user_id, system_prompt,
|
||||
refine_instruction,
|
||||
),
|
||||
temperature=0.7,
|
||||
max_tokens=_get_max_tokens(template_id),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
generated_text = response.choices[0].message.content.strip()
|
||||
|
||||
# 记录助手回复到对话历史
|
||||
Conversation.append_round(
|
||||
conversation_id=conv.id,
|
||||
user_id=user_id,
|
||||
scene_type='prompt',
|
||||
role='assistant',
|
||||
content=generated_text,
|
||||
)
|
||||
|
||||
# 保存到 Prompt 表
|
||||
prompt = Prompt(
|
||||
input_text=refine_instruction,
|
||||
generated_text=generated_text,
|
||||
user_id=user_id,
|
||||
)
|
||||
db.session.add(prompt)
|
||||
db.session.commit()
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'conversation_id': conv.id,
|
||||
'generated_text': generated_text,
|
||||
'rounds': len(conv.context_messages) // 2,
|
||||
'prompt': {
|
||||
'id': prompt.id,
|
||||
'input_text': refine_instruction,
|
||||
'generated_text': generated_text,
|
||||
},
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f'continue_prompt error: {e}')
|
||||
db.session.rollback()
|
||||
return jsonify({'success': False, 'message': f'优化失败: {str(e)}'}), 500
|
||||
|
||||
|
||||
@conversation_bp.route('/api/conversation/<int:conversation_id>', methods=['GET'])
|
||||
def get_conversation(conversation_id):
|
||||
"""获取对话历史"""
|
||||
user_id = get_current_user_id()
|
||||
conv = Conversation.query.filter_by(id=conversation_id, user_id=user_id).first()
|
||||
if not conv:
|
||||
return jsonify({'success': False, 'message': '对话不存在'}), 404
|
||||
return jsonify({'success': True, 'conversation': conv.to_dict()})
|
||||
|
||||
|
||||
@conversation_bp.route('/api/conversation/<int:conversation_id>', methods=['DELETE'])
|
||||
def clear_conversation(conversation_id):
|
||||
"""清除对话上下文(开始新对话)"""
|
||||
user_id = get_current_user_id()
|
||||
conv = Conversation.query.filter_by(id=conversation_id, user_id=user_id).first()
|
||||
if not conv:
|
||||
return jsonify({'success': False, 'message': '对话不存在'}), 404
|
||||
db.session.delete(conv)
|
||||
db.session.commit()
|
||||
return jsonify({'success': True, 'message': '对话已清除'})
|
||||
|
||||
|
||||
def _get_system_prompt(template_id):
|
||||
"""获取系统提示词"""
|
||||
if template_id:
|
||||
template = PromptTemplate.query.get(template_id)
|
||||
if template:
|
||||
return template.system_prompt
|
||||
default = PromptTemplate.query.filter_by(is_default=True).first()
|
||||
if default:
|
||||
return default.system_prompt
|
||||
return "你是一个专业的提示词工程师,请直接返回优化后的提示词,不要添加任何解释。"
|
||||
|
||||
|
||||
def _get_max_tokens(template_id):
|
||||
"""获取 max_tokens"""
|
||||
if template_id:
|
||||
template = PromptTemplate.query.get(template_id)
|
||||
if template and template.max_tokens:
|
||||
return template.max_tokens
|
||||
return 500
|
||||
@@ -17,3 +17,21 @@ export function fetchTemplatesByCategory(category: string) {
|
||||
export function generatePrompt(body: { input_text: string; template_id: number | null; max_tokens?: number }) {
|
||||
return client.post<GeneratePromptResponse>('/api/prompt/generate', body).then((r) => r.data)
|
||||
}
|
||||
|
||||
export interface ContinuePromptResponse {
|
||||
success: boolean
|
||||
message?: string
|
||||
conversation_id: number
|
||||
generated_text: string
|
||||
rounds: number
|
||||
prompt?: { id: number; input_text: string; generated_text: string }
|
||||
}
|
||||
|
||||
export function continuePrompt(body: {
|
||||
conversation_id?: number | null
|
||||
previous_result: string
|
||||
refine_instruction: string
|
||||
template_id?: number | null
|
||||
}) {
|
||||
return client.post<ContinuePromptResponse>('/api/prompt/continue', body).then((r) => r.data)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, ref, watch } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { fetchGenerateMeta, fetchTemplatesByCategory, generatePrompt } from '@/api/modules/prompt'
|
||||
import { fetchGenerateMeta, fetchTemplatesByCategory, generatePrompt, continuePrompt } from '@/api/modules/prompt'
|
||||
import { quickAddFavorite } from '@/api/modules/favorite'
|
||||
import type { PromptTemplateItem } from '@/api/types/template'
|
||||
|
||||
@@ -30,6 +30,10 @@ const showAdvanced = ref(false)
|
||||
const maxTokensOverride = ref<number | undefined>(undefined)
|
||||
|
||||
const result = ref<{ id: number; input_text: string; generated_text: string } | null>(null)
|
||||
const conversationId = ref<number | null>(null)
|
||||
const rounds = ref(0)
|
||||
const refineInstruction = ref('')
|
||||
const refining = ref(false)
|
||||
|
||||
const router = useRouter()
|
||||
|
||||
@@ -213,6 +217,8 @@ async function onSubmit() {
|
||||
}
|
||||
if (res.prompt) {
|
||||
result.value = res.prompt
|
||||
rounds.value = res.rounds || 0
|
||||
if (res.conversation_id) conversationId.value = res.conversation_id
|
||||
ElMessage.success('生成成功')
|
||||
}
|
||||
} catch (e) {
|
||||
@@ -234,6 +240,45 @@ async function copyResult() {
|
||||
}
|
||||
}
|
||||
|
||||
async function onRefine() {
|
||||
const instruction = refineInstruction.value.trim()
|
||||
if (!instruction) {
|
||||
ElMessage.warning('请输入优化指令')
|
||||
return
|
||||
}
|
||||
if (!result.value) return
|
||||
refining.value = true
|
||||
try {
|
||||
const res = await continuePrompt({
|
||||
conversation_id: conversationId.value,
|
||||
previous_result: result.value.generated_text,
|
||||
refine_instruction: instruction,
|
||||
template_id: selectedTemplateId.value,
|
||||
})
|
||||
if (!res.success) {
|
||||
ElMessage.error(res.message || '优化失败')
|
||||
return
|
||||
}
|
||||
result.value = res.prompt ?? { id: 0, input_text: instruction, generated_text: res.generated_text }
|
||||
conversationId.value = res.conversation_id
|
||||
rounds.value = res.rounds
|
||||
refineInstruction.value = ''
|
||||
ElMessage.success(`已优化 (第 ${res.rounds} 轮)`)
|
||||
} catch {
|
||||
ElMessage.error('优化请求失败')
|
||||
} finally {
|
||||
refining.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function resetConversation() {
|
||||
conversationId.value = null
|
||||
rounds.value = 0
|
||||
refineInstruction.value = ''
|
||||
result.value = null
|
||||
ElMessage.info('已开始新对话')
|
||||
}
|
||||
|
||||
async function addToFavorites() {
|
||||
if (!result.value || selectedTemplateId.value == null) return
|
||||
const tpl = selectedTemplate.value
|
||||
@@ -399,6 +444,29 @@ onMounted(async () => {
|
||||
</template>
|
||||
<pre class="result-pre">{{ result.generated_text }}</pre>
|
||||
</el-card>
|
||||
|
||||
<el-card v-if="result" shadow="never" class="section refine-card">
|
||||
<template #header>
|
||||
<span>继续优化</span>
|
||||
<el-tag v-if="rounds > 0" size="small" type="info" style="margin-left: 8px">第 {{ rounds }} 轮</el-tag>
|
||||
</template>
|
||||
<el-input
|
||||
v-model="refineInstruction"
|
||||
type="textarea"
|
||||
:rows="3"
|
||||
maxlength="500"
|
||||
show-word-limit
|
||||
placeholder="例如:精简到200字以内 / 增加技术约束 / 改为英文 / 面向非技术读者"
|
||||
/>
|
||||
<div class="actions">
|
||||
<el-space>
|
||||
<el-button type="primary" :loading="refining" @click="onRefine">
|
||||
提交优化
|
||||
</el-button>
|
||||
<el-button text type="info" @click="resetConversation">开始新对话</el-button>
|
||||
</el-space>
|
||||
</div>
|
||||
</el-card>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user