feat: persistent chat message storage + Android pull-to-load history
Some checks failed
CI/CD Pipeline / Backend — Lint & Test (push) Has been cancelled
CI/CD Pipeline / Frontend — Lint & Build (push) Has been cancelled
CI/CD Pipeline / Docker — Build Check (push) Has been cancelled

Backend:
- Add ChatMessage model + Alembic migration 024
- Add on_message callback to AgentRuntime for persisting messages during SSE streaming
- Plumb session_id from ChatRequest to AgentContext in all 4 chat endpoints
- Add GET /agent-chat/{id}/sessions and /sessions/{sid}/messages with cursor pagination

Android:
- Add DTOs/ApiService/MessageDao for server-side chat history
- ChatRepository: fetchOlderMessages (API + Room cache), offline fallback
- ChatViewModel: loadMoreHistory with isLoadingMore/hasMoreMessages state
- ChatScreen: scroll-to-top detection + top loading indicator

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 00:07:26 +08:00
parent 569e3ab7df
commit a06082480a
12 changed files with 705 additions and 18 deletions

View File

@@ -23,4 +23,10 @@ interface MessageDao {
@Query("DELETE FROM messages")
suspend fun deleteAll()
@Query("SELECT * FROM messages WHERE conversationId = :conversationId AND createdAt < :beforeTimestamp ORDER BY createdAt DESC LIMIT :limit")
suspend fun getMessagesBeforeTimestamp(conversationId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity>
@Query("SELECT COUNT(*) FROM messages WHERE conversationId = :conversationId")
suspend fun getMessageCount(conversationId: String): Int
}

View File

@@ -97,4 +97,19 @@ interface ApiService {
// ─── Feedback ───
@POST("api/v1/feedback")
suspend fun submitFeedback(@Body request: FeedbackRequest): FeedbackResponse
// ─── Chat History ───
@GET("api/v1/agent-chat/{agentId}/sessions/{sessionId}/messages")
suspend fun getSessionMessages(
@Path("agentId") agentId: String,
@Path("sessionId") sessionId: String,
@Query("before_id") beforeId: String? = null,
@Query("limit") limit: Int = 50
): MessageHistoryResponse
@GET("api/v1/agent-chat/{agentId}/sessions")
suspend fun getAgentSessions(
@Path("agentId") agentId: String,
@Query("limit") limit: Int = 50
): SessionListResponse
}

View File

@@ -275,6 +275,45 @@ data class RegisterResponse(
val message: String? = null
)
// ─────────── Feedback ───────────
// ─────────── Chat History ───────────
data class MessageItemDto(
val id: String,
@SerializedName("session_id") val sessionId: String,
@SerializedName("agent_id") val agentId: String? = null,
@SerializedName("user_id") val userId: String? = null,
val role: String,
val content: String? = null,
@SerializedName("tool_name") val toolName: String? = null,
@SerializedName("tool_input") val toolInput: String? = null,
@SerializedName("tool_output") val toolOutput: String? = null,
val iteration: Int = 0,
@SerializedName("created_at") val createdAt: String? = null
)
data class MessageHistoryResponse(
val messages: List<MessageItemDto>,
@SerializedName("has_more") val hasMore: Boolean,
val total: Int
)
data class SessionItemDto(
@SerializedName("session_id") val sessionId: String,
val title: String? = null,
@SerializedName("last_message") val lastMessage: String? = null,
@SerializedName("message_count") val messageCount: Int = 0,
@SerializedName("created_at") val createdAt: String? = null,
@SerializedName("updated_at") val updatedAt: String? = null
)
data class SessionListResponse(
val sessions: List<SessionItemDto>
)
// ─────────── Feedback ───────────
data class FeedbackRequest(

View File

@@ -8,6 +8,10 @@ import com.tiangong.aiagent.data.remote.ApiService
import com.tiangong.aiagent.data.remote.SseClient
import com.tiangong.aiagent.data.remote.dto.ChatRequest
import com.tiangong.aiagent.data.remote.dto.ChatResponse
import com.tiangong.aiagent.data.remote.dto.MessageHistoryResponse
import com.tiangong.aiagent.data.remote.dto.MessageItemDto
import com.tiangong.aiagent.data.remote.dto.SessionListResponse
import com.tiangong.aiagent.data.remote.dto.SessionItemDto
import com.tiangong.aiagent.data.remote.dto.TtsRequest
import com.tiangong.aiagent.domain.model.Conversation
import com.tiangong.aiagent.domain.model.Message
@@ -257,4 +261,136 @@ class ChatRepository @Inject constructor(
Result.failure(e)
}
}
// ─── Chat History (v1.3.0): Server-side pagination + Room cache ───
/** Fetch older messages from the server for cursor-based pagination. */
suspend fun fetchOlderMessages(
agentId: String,
sessionId: String,
beforeId: String? = null,
limit: Int = 50
): Result<Pair<List<Message>, Boolean>> {
return try {
val response: MessageHistoryResponse = apiService.getSessionMessages(
agentId = agentId,
sessionId = sessionId,
beforeId = beforeId,
limit = limit
)
// Cache fetched messages to local Room
cacheMessagesFromApi(response.messages, sessionId, agentId)
val messages = response.messages.map { it.toDomainMessage() }
Result.success(Pair(messages, response.hasMore))
} catch (e: Exception) {
Result.failure(e)
}
}
/** Fetch agent's session list from server. */
suspend fun fetchAgentSessions(agentId: String, limit: Int = 50): Result<List<SessionItemDto>> {
return try {
val response: SessionListResponse = apiService.getAgentSessions(agentId, limit)
Result.success(response.sessions)
} catch (e: Exception) {
Result.failure(e)
}
}
/** Get older messages from local Room (offline fallback). */
suspend fun getOlderMessagesFromRoom(
conversationId: String,
beforeTimestamp: Long,
limit: Int = 50
): List<Message> {
val entities = database.messageDao().getMessagesBeforeTimestamp(
conversationId = conversationId,
beforeTimestamp = beforeTimestamp,
limit = limit
)
return entities.map { it.toDomainMessage() }
}
/** Check if there are more messages available on the server for a session. */
suspend fun hasMoreServerMessages(agentId: String, sessionId: String): Boolean {
val localCount = database.messageDao().getMessageCount(sessionId)
return try {
val response = apiService.getSessionMessages(
agentId = agentId, sessionId = sessionId, beforeId = null, limit = 1
)
response.total > localCount
} catch (e: Exception) {
false
}
}
// ─── Helpers ───
private suspend fun cacheMessagesFromApi(dtos: List<MessageItemDto>, sessionId: String, agentId: String?) {
val dao = database.messageDao()
for (dto in dtos) {
dao.insert(
MessageEntity(
id = dto.id,
conversationId = sessionId,
agentId = dto.agentId ?: agentId,
role = dto.role,
content = dto.content ?: "",
toolName = dto.toolName,
toolInput = dto.toolInput,
toolOutput = dto.toolOutput,
tokenUsageJson = null,
createdAt = dto.createdAt?.let { parseIso8601ToEpoch(it) } ?: System.currentTimeMillis()
)
)
}
}
private fun MessageItemDto.toDomainMessage(): Message {
return Message(
id = id,
conversationId = sessionId,
agentId = agentId,
role = try {
Message.Role.valueOf(role.uppercase())
} catch (e: Exception) {
Message.Role.SYSTEM
},
content = content ?: "",
toolName = toolName,
toolInput = toolInput,
toolOutput = toolOutput,
createdAt = createdAt?.let { parseIso8601ToEpoch(it) } ?: System.currentTimeMillis()
)
}
private fun MessageEntity.toDomainMessage(): Message {
return Message(
id = id,
conversationId = conversationId,
agentId = agentId,
role = try {
Message.Role.valueOf(role.uppercase())
} catch (e: Exception) {
Message.Role.SYSTEM
},
content = content,
toolName = toolName,
toolInput = toolInput,
toolOutput = toolOutput,
createdAt = createdAt
)
}
companion object {
fun parseIso8601ToEpoch(isoString: String): Long {
return try {
val sdf = java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", java.util.Locale.US)
sdf.timeZone = java.util.TimeZone.getTimeZone("UTC")
sdf.parse(isoString.substringBefore('.'))?.time ?: System.currentTimeMillis()
} catch (e: Exception) {
System.currentTimeMillis()
}
}
}
}

View File

@@ -74,13 +74,27 @@ fun ChatScreen(
onDispose { lifecycleOwner.lifecycle.removeObserver(observer) }
}
// Auto-scroll when new messages arrive
// Auto-scroll when new messages arrive (only when not loading more history)
LaunchedEffect(uiState.messages.size) {
if (uiState.messages.isNotEmpty()) {
if (uiState.messages.isNotEmpty() && !uiState.isLoadingMore) {
listState.animateScrollToItem(uiState.messages.size - 1)
}
}
// Scroll-to-top detection for loading more history
val isAtTop by remember {
derivedStateOf {
listState.firstVisibleItemIndex == 0 &&
listState.firstVisibleItemScrollOffset == 0
}
}
LaunchedEffect(isAtTop) {
if (isAtTop && uiState.hasMoreMessages && !uiState.isLoadingMore) {
viewModel.loadMoreHistory()
}
}
// Pre-fill input when editing a message
LaunchedEffect(uiState.editingMessageId) {
uiState.editingMessageContent?.let { inputText = it }
@@ -452,6 +466,28 @@ fun ChatScreen(
modifier = Modifier.fillMaxSize().weight(1f),
contentPadding = PaddingValues(vertical = 8.dp)
) {
// Loading indicator at top (pull-to-load more history)
if (uiState.isLoadingMore) {
item(key = "loading_more") {
Row(
modifier = Modifier.fillMaxWidth().padding(8.dp),
horizontalArrangement = Arrangement.Center,
verticalAlignment = Alignment.CenterVertically
) {
CircularProgressIndicator(
modifier = Modifier.size(20.dp),
strokeWidth = 2.dp
)
Spacer(modifier = Modifier.width(8.dp))
Text(
text = "加载更多...",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
// Skeleton loading on first load
if (!firstLoaded && uiState.messages.isEmpty() && !uiState.isStreaming) {
item { SkeletonChat() }

View File

@@ -69,7 +69,11 @@ data class ChatUiState(
val pendingQueueCount: Int = 0,
// Think trace entries (v1.1.0)
val thinkTraces: List<ThinkTraceEntry> = emptyList()
val thinkTraces: List<ThinkTraceEntry> = emptyList(),
// Pull-to-load history (v1.3.0)
val isLoadingMore: Boolean = false,
val hasMoreMessages: Boolean = true
)
data class ThinkTraceEntry(
@@ -236,13 +240,34 @@ class ChatViewModel @Inject constructor(
historyFlowJob?.cancel()
historyFlowJob = viewModelScope.launch {
val conversations = chatRepository.getConversationsByAgent(agentId)
val latestConversation = conversations.firstOrNull() ?: return@launch
chatRepository.getMessages(latestConversation.sessionId).collect { messages ->
if (messages.isNotEmpty() && _uiState.value.currentAgent?.id == agentId) {
_uiState.value = _uiState.value.copy(
messages = messages.map { it.toUiMessage() },
sessionId = latestConversation.sessionId
val latestConversation = conversations.firstOrNull()
if (latestConversation != null) {
chatRepository.getMessages(latestConversation.sessionId).collect { messages ->
if (messages.isNotEmpty() && _uiState.value.currentAgent?.id == agentId) {
_uiState.value = _uiState.value.copy(
messages = messages.map { it.toUiMessage() },
sessionId = latestConversation.sessionId
)
}
}
} else {
// No local conversations — try fetching from server
val result = chatRepository.fetchAgentSessions(agentId, limit = 1)
if (result.isSuccess) {
val sessions = result.getOrThrow()
val latestSession = sessions.firstOrNull() ?: return@launch
val msgResult = chatRepository.fetchOlderMessages(
agentId = agentId, sessionId = latestSession.sessionId, limit = 50
)
if (msgResult.isSuccess) {
val (msgs, hasMore) = msgResult.getOrThrow()
_uiState.value = _uiState.value.copy(
messages = msgs.map { it.toUiMessage() },
sessionId = latestSession.sessionId,
hasMoreMessages = hasMore
)
tokenDataStore.saveLastSessionId(latestSession.sessionId)
}
}
}
}
@@ -271,16 +296,106 @@ class ChatViewModel @Inject constructor(
sseJob?.cancel()
viewModelScope.launch {
tokenDataStore.saveLastSessionId(sessionId)
chatRepository.getMessages(sessionId).collect { messages ->
if (messages.isNotEmpty()) {
// First check Room
val localMsgs = chatRepository.getMessages(sessionId).first()
if (localMsgs.isNotEmpty()) {
_uiState.value = _uiState.value.copy(
messages = localMsgs.map { it.toUiMessage() },
sessionId = sessionId,
streamingContent = "",
isLoading = false,
isStreaming = false,
error = null,
reconnectionState = ReconnectionState.Idle
)
// Continue observing local changes
historyFlowJob?.cancel()
historyFlowJob = viewModelScope.launch {
chatRepository.getMessages(sessionId).collect { messages ->
if (messages.isNotEmpty()) {
_uiState.value = _uiState.value.copy(
messages = messages.map { it.toUiMessage() }
)
}
}
}
} else {
// Try server
val agentId = _uiState.value.currentAgent?.id ?: return@launch
val result = chatRepository.fetchOlderMessages(
agentId = agentId, sessionId = sessionId, limit = 50
)
if (result.isSuccess) {
val (msgs, hasMore) = result.getOrThrow()
_uiState.value = _uiState.value.copy(
messages = messages.map { it.toUiMessage() },
messages = msgs.map { it.toUiMessage() },
sessionId = sessionId,
streamingContent = "",
isLoading = false,
isStreaming = false,
error = null,
reconnectionState = ReconnectionState.Idle
reconnectionState = ReconnectionState.Idle,
hasMoreMessages = hasMore
)
}
}
}
}
fun loadMoreHistory() {
val state = _uiState.value
if (state.isLoadingMore || !state.hasMoreMessages) return
val agentId = state.currentAgent?.id ?: return
val sessionId = state.sessionId ?: return
val oldestMsg = state.messages.firstOrNull() ?: return
viewModelScope.launch {
_uiState.value = _uiState.value.copy(isLoadingMore = true)
// Capture current scroll anchor (first visible item id + offset)
val anchorId = oldestMsg.id
val result = chatRepository.fetchOlderMessages(
agentId = agentId,
sessionId = sessionId,
beforeId = anchorId,
limit = 50
)
if (result.isSuccess) {
val (olderMessages, hasMore) = result.getOrThrow()
val current = _uiState.value.messages.toMutableList()
// Prepend older messages (avoid duplicates by id)
val existingIds = current.map { it.id }.toSet()
val newMsgs = olderMessages.filter { it.id !in existingIds }.map { it.toUiMessage() }
current.addAll(0, newMsgs)
_uiState.value = _uiState.value.copy(
messages = current,
hasMoreMessages = hasMore,
isLoadingMore = false
)
} else {
// API failed — try Room fallback
val oldestTimestamp = oldestMsg.createdAt
val roomMsgs = chatRepository.getOlderMessagesFromRoom(
conversationId = sessionId,
beforeTimestamp = oldestTimestamp,
limit = 50
)
if (roomMsgs.isNotEmpty()) {
val current = _uiState.value.messages.toMutableList()
val existingIds = current.map { it.id }.toSet()
val newMsgs = roomMsgs.filter { it.id !in existingIds }.map { it.toUiMessage() }
current.addAll(0, newMsgs)
_uiState.value = _uiState.value.copy(
messages = current,
hasMoreMessages = roomMsgs.size >= 50,
isLoadingMore = false
)
} else {
_uiState.value = _uiState.value.copy(
hasMoreMessages = false,
isLoadingMore = false
)
}
}

View File

@@ -0,0 +1,45 @@
"""024_add_chat_messages
Revision ID: 104e05fc9cf2
Revises: 15623091001e
Create Date: 2026-06-29 23:40:32.303965
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision = '104e05fc9cf2'
down_revision = '15623091001e'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table('chat_messages',
sa.Column('id', mysql.CHAR(length=36), nullable=False, comment='消息ID'),
sa.Column('session_id', mysql.CHAR(length=36), nullable=False, comment='会话ID'),
sa.Column('agent_id', mysql.CHAR(length=36), nullable=True, comment='智能体ID'),
sa.Column('user_id', mysql.CHAR(length=36), nullable=True, comment='用户ID'),
sa.Column('role', sa.String(length=20), nullable=False, comment='角色: user/assistant/tool/system'),
sa.Column('content', sa.Text(), nullable=True, comment='消息内容'),
sa.Column('tool_name', sa.String(length=100), nullable=True, comment='工具名称'),
sa.Column('tool_input', sa.Text(), nullable=True, comment='工具输入参数JSON'),
sa.Column('tool_output', sa.Text(), nullable=True, comment='工具输出结果'),
sa.Column('iteration', sa.Integer(), nullable=True, comment='Agent迭代序号'),
sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'),
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='SET NULL'),
sa.PrimaryKeyConstraint('id'),
mysql_collate='utf8mb4_unicode_ci',
mysql_default_charset='utf8mb4',
mysql_engine='InnoDB'
)
op.create_index('ix_chat_messages_session_created', 'chat_messages', ['session_id', 'created_at'], unique=False)
def downgrade() -> None:
op.drop_index('ix_chat_messages_session_created', table_name='chat_messages')
op.drop_table('chat_messages')

View File

@@ -105,6 +105,7 @@ class AgentRuntime:
execution_logger: Optional[Any] = None,
on_tool_executed: Optional[Callable[[str], Any]] = None,
on_llm_call: Optional[Callable[[Dict[str, Any]], Any]] = None,
on_message: Optional[Callable[[Dict[str, Any]], Any]] = None,
hook_manager: Optional[HookManager] = None,
streamlined: bool = False,
):
@@ -141,6 +142,7 @@ class AgentRuntime:
self.execution_logger = execution_logger
self.on_tool_executed = on_tool_executed
self.on_llm_call = on_llm_call
self.on_message = on_message
self._memory_context_loaded = False
self._llm_invocations = 0
# 自主学习作用域bare 聊天用 "bare"Agent 用 "agent"
@@ -808,6 +810,12 @@ class AgentRuntime:
# 2. 追加用户消息
self.context.add_user_message(user_input)
if self.on_message:
self.on_message({
"role": "user", "content": user_input,
"session_id": self.context.session_id,
"iteration": 0,
})
# 2.5 计划模式 (P2) — 流式生成执行计划
plan: Optional[Plan] = None
@@ -969,6 +977,12 @@ class AgentRuntime:
# LLM 直接返回文本 → 结束
self.context.add_assistant_message(content)
final_text = content or "(模型未返回有效内容)"
if self.on_message:
self.on_message({
"role": "assistant", "content": final_text,
"session_id": self.context.session_id,
"iteration": self.context.iteration,
})
review_score = 0.0
# 输出质量自检(默认关闭)
@@ -1027,6 +1041,12 @@ class AgentRuntime:
# 有工具调用 → 先记录 assistant 消息
self.context.add_assistant_message(content or "", tool_calls, reasoning)
if self.on_message:
self.on_message({
"role": "assistant", "content": content or "",
"session_id": self.context.session_id,
"iteration": self.context.iteration,
})
# yield think 事件
tc_names = [tc["function"]["name"] for tc in tool_calls]
@@ -1087,6 +1107,15 @@ class AgentRuntime:
result = json.dumps({"error": hook_res.reason}, ensure_ascii=False)
yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration}
self.context.add_tool_result(tcid, tname, result)
if self.on_message:
self.on_message({
"role": "tool", "content": result,
"session_id": self.context.session_id,
"iteration": self.context.iteration,
"tool_name": tname,
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
"tool_output": result,
})
continue
if hook_res.modified_input:
targs = hook_res.modified_input
@@ -1119,11 +1148,29 @@ class AgentRuntime:
result = f"[审批拒绝] 工具 {tname} 需要人工审批但被拒绝。"
yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration}
self.context.add_tool_result(tcid, tname, result)
if self.on_message:
self.on_message({
"role": "tool", "content": result,
"session_id": self.context.session_id,
"iteration": self.context.iteration,
"tool_name": tname,
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
"tool_output": result,
})
continue
elif decision == "skip":
result = f"[审批跳过] 工具 {tname} 被跳过。"
yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration}
self.context.add_tool_result(tcid, tname, result)
if self.on_message:
self.on_message({
"role": "tool", "content": result,
"session_id": self.context.session_id,
"iteration": self.context.iteration,
"tool_name": tname,
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
"tool_output": result,
})
continue
# decision == "approved" → 继续执行
@@ -1154,6 +1201,15 @@ class AgentRuntime:
))
self.context.add_tool_result(tcid, tname, result)
if self.on_message:
self.on_message({
"role": "tool", "content": result,
"session_id": self.context.session_id,
"iteration": self.context.iteration,
"tool_name": tname,
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
"tool_output": result,
})
self.context.tool_calls_made += 1
# Hook: PostToolUse — 工具执行后处理 (流式)

View File

@@ -28,9 +28,11 @@ from app.agent_runtime import (
AgentBudgetConfig,
AgentMemoryConfig,
AgentStep,
AgentContext,
AgentOrchestrator,
OrchestratorAgentConfig,
)
from app.models.chat_message import ChatMessage
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -68,6 +70,32 @@ def _make_llm_logger(
return _log
def _make_message_saver(
db: Session,
agent_id: Optional[str] = None,
user_id: Optional[str] = None,
):
"""创建消息持久化回调,将每条消息写入 chat_messages 表。"""
def _save(msg: dict):
try:
record = ChatMessage(
session_id=msg.get("session_id"),
agent_id=agent_id,
user_id=user_id,
role=msg.get("role", "user"),
content=msg.get("content"),
tool_name=msg.get("tool_name"),
tool_input=msg.get("tool_input"),
tool_output=msg.get("tool_output"),
iteration=msg.get("iteration", 0),
)
db.add(record)
db.commit()
except Exception as e:
logger.warning("写入 ChatMessage 失败: %s", e)
return _save
async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
"""将 run_stream 生成的 dict 事件格式化为 SSE 文本流。"""
async for event in gen:
@@ -99,6 +127,43 @@ class ChatResponse(BaseModel):
token_usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 预算摘要")
class MessageItem(BaseModel):
"""消息历史条目"""
id: str
session_id: str
agent_id: Optional[str] = None
user_id: Optional[str] = None
role: str
content: Optional[str] = None
tool_name: Optional[str] = None
tool_input: Optional[str] = None
tool_output: Optional[str] = None
iteration: int = 0
created_at: Optional[str] = None
class MessageHistoryResponse(BaseModel):
"""消息历史分页响应"""
messages: List[MessageItem]
has_more: bool
total: int
class SessionItem(BaseModel):
"""会话列表条目"""
session_id: str
title: Optional[str] = None
last_message: Optional[str] = None
message_count: int = 0
created_at: Optional[str] = None
updated_at: Optional[str] = None
class SessionListResponse(BaseModel):
"""会话列表响应"""
sessions: List[SessionItem]
class OrchestrateAgentItem(BaseModel):
"""编排中单个 Agent 的定义"""
id: str
@@ -270,7 +335,9 @@ async def chat_bare(
if req.system_prompt_override:
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
on_message = _make_message_saver(db, agent_id=None, user_id=current_user.id)
context = AgentContext(session_id=req.session_id)
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
result = await runtime.run(req.message)
# 流式美化:为 steps 生成累计摘要
@@ -333,7 +400,9 @@ async def chat_bare_stream(
if req.system_prompt_override:
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
on_message = _make_message_saver(db, agent_id=None, user_id=current_user.id)
context = AgentContext(session_id=req.session_id)
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
return StreamingResponse(
_sse_stream(runtime.run_stream(req.message)),
media_type="text/event-stream",
@@ -417,7 +486,9 @@ async def chat_with_agent(
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
on_message = _make_message_saver(db, agent_id=agent_id, user_id=current_user.id)
context = AgentContext(session_id=req.session_id)
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
result = await runtime.run(req.message)
# 流式美化:为 steps 生成累计摘要
@@ -511,7 +582,9 @@ async def chat_with_agent_stream(
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
on_message = _make_message_saver(db, agent_id=agent_id, user_id=current_user.id)
context = AgentContext(session_id=req.session_id)
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
return StreamingResponse(
_sse_stream(runtime.run_stream(req.message)),
media_type="text/event-stream",
@@ -523,6 +596,127 @@ async def chat_with_agent_stream(
)
@router.get("/{agent_id}/sessions", response_model=SessionListResponse)
async def list_agent_sessions(
agent_id: str,
limit: int = 50,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取 Agent 的会话列表,按最近活跃时间排序。"""
from sqlalchemy import func as sa_func, desc
# 验证 agent 存在或有权限
agent = db.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
raise HTTPException(status_code=404, detail="Agent 不存在")
rows = (
db.query(
ChatMessage.session_id,
sa_func.min(ChatMessage.created_at).label("created_at"),
sa_func.max(ChatMessage.created_at).label("updated_at"),
sa_func.count(ChatMessage.id).label("message_count"),
)
.filter(ChatMessage.agent_id == agent_id)
.group_by(ChatMessage.session_id)
.order_by(desc("updated_at"))
.limit(limit)
.all()
)
sessions = []
for row in rows:
# 取第一条 user 消息作为标题
first_user_msg = (
db.query(ChatMessage)
.filter(
ChatMessage.session_id == row.session_id,
ChatMessage.role == "user",
)
.order_by(ChatMessage.created_at.asc())
.first()
)
# 取最后一条消息作为预览
last_msg = (
db.query(ChatMessage)
.filter(ChatMessage.session_id == row.session_id)
.order_by(ChatMessage.created_at.desc())
.first()
)
sessions.append(SessionItem(
session_id=row.session_id,
title=first_user_msg.content[:100] if first_user_msg and first_user_msg.content else None,
last_message=last_msg.content[:200] if last_msg and last_msg.content else None,
message_count=row.message_count,
created_at=row.created_at.isoformat() if row.created_at else None,
updated_at=row.updated_at.isoformat() if row.updated_at else None,
))
return SessionListResponse(sessions=sessions)
@router.get("/{agent_id}/sessions/{session_id}/messages", response_model=MessageHistoryResponse)
async def get_session_messages(
agent_id: str,
session_id: str,
before_id: Optional[str] = None,
limit: int = 50,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取会话的消息历史(分页),从旧到新排序。"""
# limit 限制
limit = min(max(limit, 1), 200)
base_q = db.query(ChatMessage).filter(
ChatMessage.agent_id == agent_id,
ChatMessage.session_id == session_id,
)
# 游标分页before_id 之前的老消息
if before_id:
cursor_msg = db.query(ChatMessage).filter(ChatMessage.id == before_id).first()
if cursor_msg and cursor_msg.created_at:
base_q = base_q.filter(ChatMessage.created_at < cursor_msg.created_at)
# 取 N+1 条,判断 has_more按时间降序取最新 N 条,再反转)
batch = (
base_q
.order_by(ChatMessage.created_at.desc())
.limit(limit + 1)
.all()
)
has_more = len(batch) > limit
if has_more:
batch = batch[:limit]
# 反转为从旧到新
batch.reverse()
total = base_q.count()
messages = [
MessageItem(
id=m.id,
session_id=m.session_id,
agent_id=m.agent_id,
user_id=m.user_id,
role=m.role,
content=m.content,
tool_name=m.tool_name,
tool_input=m.tool_input,
tool_output=m.tool_output,
iteration=m.iteration or 0,
created_at=m.created_at.isoformat() if m.created_at else None,
)
for m in batch
]
return MessageHistoryResponse(messages=messages, has_more=has_more, total=total)
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
"""从工作流节点列表中查找第一个 agent 类型或 llm 类型的节点配置。"""
if not nodes:

View File

@@ -76,4 +76,5 @@ def init_db():
import app.models.workspace
import app.models.scene_contract
import app.models.team
import app.models.chat_message
Base.metadata.create_all(bind=engine)

View File

@@ -33,5 +33,6 @@ from app.models.audit_log import AuditLog
from app.models.workspace import Workspace, WorkspaceMembership
from app.models.scene_contract import SceneContract
from app.models.team import Team, TeamMember
from app.models.chat_message import ChatMessage
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "GlobalKnowledge", "AgentRating", "AgentFavorite", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "AgentSchedule", "KnowledgeBase", "Document", "DocumentChunk", "Notification", "UserFeishuOpenId", "NodePlugin", "OrchestrationTemplate", "Goal", "Task", "AgentExecutionLog", "UserBehaviorLog", "KnowledgeEntry", "UserFingerprint", "ShadowComparison", "FeedbackRecord", "AuditLog", "Workspace", "WorkspaceMembership", "SceneContract", "Team", "TeamMember"]
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "GlobalKnowledge", "AgentRating", "AgentFavorite", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "AgentSchedule", "KnowledgeBase", "Document", "DocumentChunk", "Notification", "UserFeishuOpenId", "NodePlugin", "OrchestrationTemplate", "Goal", "Task", "AgentExecutionLog", "UserBehaviorLog", "KnowledgeEntry", "UserFingerprint", "ShadowComparison", "FeedbackRecord", "AuditLog", "Workspace", "WorkspaceMembership", "SceneContract", "Team", "TeamMember", "ChatMessage"]

View File

@@ -0,0 +1,43 @@
"""
Chat Message 持久化模型 — 保存 Agent 会话中的每条消息
"""
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Index, func
from sqlalchemy.dialects.mysql import CHAR
from app.core.database import Base
import uuid
class ChatMessage(Base):
"""聊天消息表"""
__tablename__ = "chat_messages"
id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="消息ID")
session_id = Column(CHAR(36), nullable=False, comment="会话ID")
agent_id = Column(CHAR(36), ForeignKey("agents.id", ondelete="SET NULL"), nullable=True, comment="智能体ID")
user_id = Column(CHAR(36), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, comment="用户ID")
role = Column(String(20), nullable=False, comment="角色: user/assistant/tool/system")
content = Column(Text, comment="消息内容")
tool_name = Column(String(100), nullable=True, comment="工具名称(仅tool消息)")
tool_input = Column(Text, nullable=True, comment="工具输入参数JSON(仅tool消息)")
tool_output = Column(Text, nullable=True, comment="工具输出结果(仅tool消息)")
iteration = Column(Integer, default=0, comment="Agent 迭代序号")
created_at = Column(DateTime, default=func.now(), comment="创建时间")
__table_args__ = (
Index("ix_chat_messages_session_created", "session_id", "created_at"),
)
def to_dict(self):
return {
"id": self.id,
"session_id": self.session_id,
"agent_id": self.agent_id,
"user_id": self.user_id,
"role": self.role,
"content": self.content,
"tool_name": self.tool_name,
"tool_input": self.tool_input,
"tool_output": self.tool_output,
"iteration": self.iteration,
"created_at": self.created_at.isoformat() if self.created_at else None,
}