From a06082480adc070a3a4c50e1f1eac5da4eeebb96 Mon Sep 17 00:00:00 2001 From: renjianbo <263303411@qq.com> Date: Tue, 30 Jun 2026 00:07:26 +0800 Subject: [PATCH] feat: persistent chat message storage + Android pull-to-load history Backend: - Add ChatMessage model + Alembic migration 024 - Add on_message callback to AgentRuntime for persisting messages during SSE streaming - Plumb session_id from ChatRequest to AgentContext in all 4 chat endpoints - Add GET /agent-chat/{id}/sessions and /sessions/{sid}/messages with cursor pagination Android: - Add DTOs/ApiService/MessageDao for server-side chat history - ChatRepository: fetchOlderMessages (API + Room cache), offline fallback - ChatViewModel: loadMoreHistory with isLoadingMore/hasMoreMessages state - ChatScreen: scroll-to-top detection + top loading indicator Co-Authored-By: Claude Opus 4.6 --- .../tiangong/aiagent/data/local/MessageDao.kt | 6 + .../aiagent/data/remote/ApiService.kt | 15 ++ .../tiangong/aiagent/data/remote/dto/Dtos.kt | 39 ++++ .../aiagent/data/repository/ChatRepository.kt | 136 ++++++++++++ .../tiangong/aiagent/ui/chat/ChatScreen.kt | 40 +++- .../tiangong/aiagent/ui/chat/ChatViewModel.kt | 137 +++++++++++- .../104e05fc9cf2_024_add_chat_messages.py | 45 ++++ backend/app/agent_runtime/core.py | 56 +++++ backend/app/api/agent_chat.py | 202 +++++++++++++++++- backend/app/core/database.py | 1 + backend/app/models/__init__.py | 3 +- backend/app/models/chat_message.py | 43 ++++ 12 files changed, 705 insertions(+), 18 deletions(-) create mode 100644 backend/alembic/versions/104e05fc9cf2_024_add_chat_messages.py create mode 100644 backend/app/models/chat_message.py diff --git a/android/app/src/main/java/com/tiangong/aiagent/data/local/MessageDao.kt b/android/app/src/main/java/com/tiangong/aiagent/data/local/MessageDao.kt index 2d5c829..2eab3ce 100644 --- a/android/app/src/main/java/com/tiangong/aiagent/data/local/MessageDao.kt +++ b/android/app/src/main/java/com/tiangong/aiagent/data/local/MessageDao.kt @@ -23,4 +23,10 @@ interface MessageDao { @Query("DELETE FROM messages") suspend fun deleteAll() + + @Query("SELECT * FROM messages WHERE conversationId = :conversationId AND createdAt < :beforeTimestamp ORDER BY createdAt DESC LIMIT :limit") + suspend fun getMessagesBeforeTimestamp(conversationId: String, beforeTimestamp: Long, limit: Int): List + + @Query("SELECT COUNT(*) FROM messages WHERE conversationId = :conversationId") + suspend fun getMessageCount(conversationId: String): Int } diff --git a/android/app/src/main/java/com/tiangong/aiagent/data/remote/ApiService.kt b/android/app/src/main/java/com/tiangong/aiagent/data/remote/ApiService.kt index 2652ebc..41b7e4a 100644 --- a/android/app/src/main/java/com/tiangong/aiagent/data/remote/ApiService.kt +++ b/android/app/src/main/java/com/tiangong/aiagent/data/remote/ApiService.kt @@ -97,4 +97,19 @@ interface ApiService { // ─── Feedback ─── @POST("api/v1/feedback") suspend fun submitFeedback(@Body request: FeedbackRequest): FeedbackResponse + + // ─── Chat History ─── + @GET("api/v1/agent-chat/{agentId}/sessions/{sessionId}/messages") + suspend fun getSessionMessages( + @Path("agentId") agentId: String, + @Path("sessionId") sessionId: String, + @Query("before_id") beforeId: String? = null, + @Query("limit") limit: Int = 50 + ): MessageHistoryResponse + + @GET("api/v1/agent-chat/{agentId}/sessions") + suspend fun getAgentSessions( + @Path("agentId") agentId: String, + @Query("limit") limit: Int = 50 + ): SessionListResponse } diff --git a/android/app/src/main/java/com/tiangong/aiagent/data/remote/dto/Dtos.kt b/android/app/src/main/java/com/tiangong/aiagent/data/remote/dto/Dtos.kt index 6835350..469d461 100644 --- a/android/app/src/main/java/com/tiangong/aiagent/data/remote/dto/Dtos.kt +++ b/android/app/src/main/java/com/tiangong/aiagent/data/remote/dto/Dtos.kt @@ -275,6 +275,45 @@ data class RegisterResponse( val message: String? = null ) +// ─────────── Feedback ─────────── + + +// ─────────── Chat History ─────────── + +data class MessageItemDto( + val id: String, + @SerializedName("session_id") val sessionId: String, + @SerializedName("agent_id") val agentId: String? = null, + @SerializedName("user_id") val userId: String? = null, + val role: String, + val content: String? = null, + @SerializedName("tool_name") val toolName: String? = null, + @SerializedName("tool_input") val toolInput: String? = null, + @SerializedName("tool_output") val toolOutput: String? = null, + val iteration: Int = 0, + @SerializedName("created_at") val createdAt: String? = null +) + +data class MessageHistoryResponse( + val messages: List, + @SerializedName("has_more") val hasMore: Boolean, + val total: Int +) + +data class SessionItemDto( + @SerializedName("session_id") val sessionId: String, + val title: String? = null, + @SerializedName("last_message") val lastMessage: String? = null, + @SerializedName("message_count") val messageCount: Int = 0, + @SerializedName("created_at") val createdAt: String? = null, + @SerializedName("updated_at") val updatedAt: String? = null +) + +data class SessionListResponse( + val sessions: List +) + + // ─────────── Feedback ─────────── data class FeedbackRequest( diff --git a/android/app/src/main/java/com/tiangong/aiagent/data/repository/ChatRepository.kt b/android/app/src/main/java/com/tiangong/aiagent/data/repository/ChatRepository.kt index 29f91d1..6691b45 100644 --- a/android/app/src/main/java/com/tiangong/aiagent/data/repository/ChatRepository.kt +++ b/android/app/src/main/java/com/tiangong/aiagent/data/repository/ChatRepository.kt @@ -8,6 +8,10 @@ import com.tiangong.aiagent.data.remote.ApiService import com.tiangong.aiagent.data.remote.SseClient import com.tiangong.aiagent.data.remote.dto.ChatRequest import com.tiangong.aiagent.data.remote.dto.ChatResponse +import com.tiangong.aiagent.data.remote.dto.MessageHistoryResponse +import com.tiangong.aiagent.data.remote.dto.MessageItemDto +import com.tiangong.aiagent.data.remote.dto.SessionListResponse +import com.tiangong.aiagent.data.remote.dto.SessionItemDto import com.tiangong.aiagent.data.remote.dto.TtsRequest import com.tiangong.aiagent.domain.model.Conversation import com.tiangong.aiagent.domain.model.Message @@ -257,4 +261,136 @@ class ChatRepository @Inject constructor( Result.failure(e) } } + + // ─── Chat History (v1.3.0): Server-side pagination + Room cache ─── + + /** Fetch older messages from the server for cursor-based pagination. */ + suspend fun fetchOlderMessages( + agentId: String, + sessionId: String, + beforeId: String? = null, + limit: Int = 50 + ): Result, Boolean>> { + return try { + val response: MessageHistoryResponse = apiService.getSessionMessages( + agentId = agentId, + sessionId = sessionId, + beforeId = beforeId, + limit = limit + ) + // Cache fetched messages to local Room + cacheMessagesFromApi(response.messages, sessionId, agentId) + val messages = response.messages.map { it.toDomainMessage() } + Result.success(Pair(messages, response.hasMore)) + } catch (e: Exception) { + Result.failure(e) + } + } + + /** Fetch agent's session list from server. */ + suspend fun fetchAgentSessions(agentId: String, limit: Int = 50): Result> { + return try { + val response: SessionListResponse = apiService.getAgentSessions(agentId, limit) + Result.success(response.sessions) + } catch (e: Exception) { + Result.failure(e) + } + } + + /** Get older messages from local Room (offline fallback). */ + suspend fun getOlderMessagesFromRoom( + conversationId: String, + beforeTimestamp: Long, + limit: Int = 50 + ): List { + val entities = database.messageDao().getMessagesBeforeTimestamp( + conversationId = conversationId, + beforeTimestamp = beforeTimestamp, + limit = limit + ) + return entities.map { it.toDomainMessage() } + } + + /** Check if there are more messages available on the server for a session. */ + suspend fun hasMoreServerMessages(agentId: String, sessionId: String): Boolean { + val localCount = database.messageDao().getMessageCount(sessionId) + return try { + val response = apiService.getSessionMessages( + agentId = agentId, sessionId = sessionId, beforeId = null, limit = 1 + ) + response.total > localCount + } catch (e: Exception) { + false + } + } + + // ─── Helpers ─── + + private suspend fun cacheMessagesFromApi(dtos: List, sessionId: String, agentId: String?) { + val dao = database.messageDao() + for (dto in dtos) { + dao.insert( + MessageEntity( + id = dto.id, + conversationId = sessionId, + agentId = dto.agentId ?: agentId, + role = dto.role, + content = dto.content ?: "", + toolName = dto.toolName, + toolInput = dto.toolInput, + toolOutput = dto.toolOutput, + tokenUsageJson = null, + createdAt = dto.createdAt?.let { parseIso8601ToEpoch(it) } ?: System.currentTimeMillis() + ) + ) + } + } + + private fun MessageItemDto.toDomainMessage(): Message { + return Message( + id = id, + conversationId = sessionId, + agentId = agentId, + role = try { + Message.Role.valueOf(role.uppercase()) + } catch (e: Exception) { + Message.Role.SYSTEM + }, + content = content ?: "", + toolName = toolName, + toolInput = toolInput, + toolOutput = toolOutput, + createdAt = createdAt?.let { parseIso8601ToEpoch(it) } ?: System.currentTimeMillis() + ) + } + + private fun MessageEntity.toDomainMessage(): Message { + return Message( + id = id, + conversationId = conversationId, + agentId = agentId, + role = try { + Message.Role.valueOf(role.uppercase()) + } catch (e: Exception) { + Message.Role.SYSTEM + }, + content = content, + toolName = toolName, + toolInput = toolInput, + toolOutput = toolOutput, + createdAt = createdAt + ) + } + + companion object { + fun parseIso8601ToEpoch(isoString: String): Long { + return try { + val sdf = java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", java.util.Locale.US) + sdf.timeZone = java.util.TimeZone.getTimeZone("UTC") + sdf.parse(isoString.substringBefore('.'))?.time ?: System.currentTimeMillis() + } catch (e: Exception) { + System.currentTimeMillis() + } + } + } } diff --git a/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatScreen.kt b/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatScreen.kt index 452bc64..44bca43 100644 --- a/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatScreen.kt +++ b/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatScreen.kt @@ -74,13 +74,27 @@ fun ChatScreen( onDispose { lifecycleOwner.lifecycle.removeObserver(observer) } } - // Auto-scroll when new messages arrive + // Auto-scroll when new messages arrive (only when not loading more history) LaunchedEffect(uiState.messages.size) { - if (uiState.messages.isNotEmpty()) { + if (uiState.messages.isNotEmpty() && !uiState.isLoadingMore) { listState.animateScrollToItem(uiState.messages.size - 1) } } + // Scroll-to-top detection for loading more history + val isAtTop by remember { + derivedStateOf { + listState.firstVisibleItemIndex == 0 && + listState.firstVisibleItemScrollOffset == 0 + } + } + + LaunchedEffect(isAtTop) { + if (isAtTop && uiState.hasMoreMessages && !uiState.isLoadingMore) { + viewModel.loadMoreHistory() + } + } + // Pre-fill input when editing a message LaunchedEffect(uiState.editingMessageId) { uiState.editingMessageContent?.let { inputText = it } @@ -452,6 +466,28 @@ fun ChatScreen( modifier = Modifier.fillMaxSize().weight(1f), contentPadding = PaddingValues(vertical = 8.dp) ) { + // Loading indicator at top (pull-to-load more history) + if (uiState.isLoadingMore) { + item(key = "loading_more") { + Row( + modifier = Modifier.fillMaxWidth().padding(8.dp), + horizontalArrangement = Arrangement.Center, + verticalAlignment = Alignment.CenterVertically + ) { + CircularProgressIndicator( + modifier = Modifier.size(20.dp), + strokeWidth = 2.dp + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = "加载更多...", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + // Skeleton loading on first load if (!firstLoaded && uiState.messages.isEmpty() && !uiState.isStreaming) { item { SkeletonChat() } diff --git a/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatViewModel.kt b/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatViewModel.kt index d0c7b7a..1163f67 100644 --- a/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatViewModel.kt +++ b/android/app/src/main/java/com/tiangong/aiagent/ui/chat/ChatViewModel.kt @@ -69,7 +69,11 @@ data class ChatUiState( val pendingQueueCount: Int = 0, // Think trace entries (v1.1.0) - val thinkTraces: List = emptyList() + val thinkTraces: List = emptyList(), + + // Pull-to-load history (v1.3.0) + val isLoadingMore: Boolean = false, + val hasMoreMessages: Boolean = true ) data class ThinkTraceEntry( @@ -236,13 +240,34 @@ class ChatViewModel @Inject constructor( historyFlowJob?.cancel() historyFlowJob = viewModelScope.launch { val conversations = chatRepository.getConversationsByAgent(agentId) - val latestConversation = conversations.firstOrNull() ?: return@launch - chatRepository.getMessages(latestConversation.sessionId).collect { messages -> - if (messages.isNotEmpty() && _uiState.value.currentAgent?.id == agentId) { - _uiState.value = _uiState.value.copy( - messages = messages.map { it.toUiMessage() }, - sessionId = latestConversation.sessionId + val latestConversation = conversations.firstOrNull() + if (latestConversation != null) { + chatRepository.getMessages(latestConversation.sessionId).collect { messages -> + if (messages.isNotEmpty() && _uiState.value.currentAgent?.id == agentId) { + _uiState.value = _uiState.value.copy( + messages = messages.map { it.toUiMessage() }, + sessionId = latestConversation.sessionId + ) + } + } + } else { + // No local conversations — try fetching from server + val result = chatRepository.fetchAgentSessions(agentId, limit = 1) + if (result.isSuccess) { + val sessions = result.getOrThrow() + val latestSession = sessions.firstOrNull() ?: return@launch + val msgResult = chatRepository.fetchOlderMessages( + agentId = agentId, sessionId = latestSession.sessionId, limit = 50 ) + if (msgResult.isSuccess) { + val (msgs, hasMore) = msgResult.getOrThrow() + _uiState.value = _uiState.value.copy( + messages = msgs.map { it.toUiMessage() }, + sessionId = latestSession.sessionId, + hasMoreMessages = hasMore + ) + tokenDataStore.saveLastSessionId(latestSession.sessionId) + } } } } @@ -271,16 +296,106 @@ class ChatViewModel @Inject constructor( sseJob?.cancel() viewModelScope.launch { tokenDataStore.saveLastSessionId(sessionId) - chatRepository.getMessages(sessionId).collect { messages -> - if (messages.isNotEmpty()) { + // First check Room + val localMsgs = chatRepository.getMessages(sessionId).first() + if (localMsgs.isNotEmpty()) { + _uiState.value = _uiState.value.copy( + messages = localMsgs.map { it.toUiMessage() }, + sessionId = sessionId, + streamingContent = "", + isLoading = false, + isStreaming = false, + error = null, + reconnectionState = ReconnectionState.Idle + ) + // Continue observing local changes + historyFlowJob?.cancel() + historyFlowJob = viewModelScope.launch { + chatRepository.getMessages(sessionId).collect { messages -> + if (messages.isNotEmpty()) { + _uiState.value = _uiState.value.copy( + messages = messages.map { it.toUiMessage() } + ) + } + } + } + } else { + // Try server + val agentId = _uiState.value.currentAgent?.id ?: return@launch + val result = chatRepository.fetchOlderMessages( + agentId = agentId, sessionId = sessionId, limit = 50 + ) + if (result.isSuccess) { + val (msgs, hasMore) = result.getOrThrow() _uiState.value = _uiState.value.copy( - messages = messages.map { it.toUiMessage() }, + messages = msgs.map { it.toUiMessage() }, sessionId = sessionId, streamingContent = "", isLoading = false, isStreaming = false, error = null, - reconnectionState = ReconnectionState.Idle + reconnectionState = ReconnectionState.Idle, + hasMoreMessages = hasMore + ) + } + } + } + } + + fun loadMoreHistory() { + val state = _uiState.value + if (state.isLoadingMore || !state.hasMoreMessages) return + val agentId = state.currentAgent?.id ?: return + val sessionId = state.sessionId ?: return + val oldestMsg = state.messages.firstOrNull() ?: return + + viewModelScope.launch { + _uiState.value = _uiState.value.copy(isLoadingMore = true) + + // Capture current scroll anchor (first visible item id + offset) + val anchorId = oldestMsg.id + + val result = chatRepository.fetchOlderMessages( + agentId = agentId, + sessionId = sessionId, + beforeId = anchorId, + limit = 50 + ) + + if (result.isSuccess) { + val (olderMessages, hasMore) = result.getOrThrow() + val current = _uiState.value.messages.toMutableList() + // Prepend older messages (avoid duplicates by id) + val existingIds = current.map { it.id }.toSet() + val newMsgs = olderMessages.filter { it.id !in existingIds }.map { it.toUiMessage() } + current.addAll(0, newMsgs) + _uiState.value = _uiState.value.copy( + messages = current, + hasMoreMessages = hasMore, + isLoadingMore = false + ) + } else { + // API failed — try Room fallback + val oldestTimestamp = oldestMsg.createdAt + val roomMsgs = chatRepository.getOlderMessagesFromRoom( + conversationId = sessionId, + beforeTimestamp = oldestTimestamp, + limit = 50 + ) + if (roomMsgs.isNotEmpty()) { + val current = _uiState.value.messages.toMutableList() + val existingIds = current.map { it.id }.toSet() + val newMsgs = roomMsgs.filter { it.id !in existingIds }.map { it.toUiMessage() } + current.addAll(0, newMsgs) + _uiState.value = _uiState.value.copy( + messages = current, + hasMoreMessages = roomMsgs.size >= 50, + isLoadingMore = false + ) + } else { + _uiState.value = _uiState.value.copy( + hasMoreMessages = false, + isLoadingMore = false ) } } diff --git a/backend/alembic/versions/104e05fc9cf2_024_add_chat_messages.py b/backend/alembic/versions/104e05fc9cf2_024_add_chat_messages.py new file mode 100644 index 0000000..2a7aea4 --- /dev/null +++ b/backend/alembic/versions/104e05fc9cf2_024_add_chat_messages.py @@ -0,0 +1,45 @@ +"""024_add_chat_messages + +Revision ID: 104e05fc9cf2 +Revises: 15623091001e +Create Date: 2026-06-29 23:40:32.303965 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + + +# revision identifiers, used by Alembic. +revision = '104e05fc9cf2' +down_revision = '15623091001e' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table('chat_messages', + sa.Column('id', mysql.CHAR(length=36), nullable=False, comment='消息ID'), + sa.Column('session_id', mysql.CHAR(length=36), nullable=False, comment='会话ID'), + sa.Column('agent_id', mysql.CHAR(length=36), nullable=True, comment='智能体ID'), + sa.Column('user_id', mysql.CHAR(length=36), nullable=True, comment='用户ID'), + sa.Column('role', sa.String(length=20), nullable=False, comment='角色: user/assistant/tool/system'), + sa.Column('content', sa.Text(), nullable=True, comment='消息内容'), + sa.Column('tool_name', sa.String(length=100), nullable=True, comment='工具名称'), + sa.Column('tool_input', sa.Text(), nullable=True, comment='工具输入参数JSON'), + sa.Column('tool_output', sa.Text(), nullable=True, comment='工具输出结果'), + sa.Column('iteration', sa.Integer(), nullable=True, comment='Agent迭代序号'), + sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'), + sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='SET NULL'), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='SET NULL'), + sa.PrimaryKeyConstraint('id'), + mysql_collate='utf8mb4_unicode_ci', + mysql_default_charset='utf8mb4', + mysql_engine='InnoDB' + ) + op.create_index('ix_chat_messages_session_created', 'chat_messages', ['session_id', 'created_at'], unique=False) + + +def downgrade() -> None: + op.drop_index('ix_chat_messages_session_created', table_name='chat_messages') + op.drop_table('chat_messages') diff --git a/backend/app/agent_runtime/core.py b/backend/app/agent_runtime/core.py index d8fb371..53b2e9f 100644 --- a/backend/app/agent_runtime/core.py +++ b/backend/app/agent_runtime/core.py @@ -105,6 +105,7 @@ class AgentRuntime: execution_logger: Optional[Any] = None, on_tool_executed: Optional[Callable[[str], Any]] = None, on_llm_call: Optional[Callable[[Dict[str, Any]], Any]] = None, + on_message: Optional[Callable[[Dict[str, Any]], Any]] = None, hook_manager: Optional[HookManager] = None, streamlined: bool = False, ): @@ -141,6 +142,7 @@ class AgentRuntime: self.execution_logger = execution_logger self.on_tool_executed = on_tool_executed self.on_llm_call = on_llm_call + self.on_message = on_message self._memory_context_loaded = False self._llm_invocations = 0 # 自主学习作用域:bare 聊天用 "bare",Agent 用 "agent" @@ -808,6 +810,12 @@ class AgentRuntime: # 2. 追加用户消息 self.context.add_user_message(user_input) + if self.on_message: + self.on_message({ + "role": "user", "content": user_input, + "session_id": self.context.session_id, + "iteration": 0, + }) # 2.5 计划模式 (P2) — 流式生成执行计划 plan: Optional[Plan] = None @@ -969,6 +977,12 @@ class AgentRuntime: # LLM 直接返回文本 → 结束 self.context.add_assistant_message(content) final_text = content or "(模型未返回有效内容)" + if self.on_message: + self.on_message({ + "role": "assistant", "content": final_text, + "session_id": self.context.session_id, + "iteration": self.context.iteration, + }) review_score = 0.0 # 输出质量自检(默认关闭) @@ -1027,6 +1041,12 @@ class AgentRuntime: # 有工具调用 → 先记录 assistant 消息 self.context.add_assistant_message(content or "", tool_calls, reasoning) + if self.on_message: + self.on_message({ + "role": "assistant", "content": content or "", + "session_id": self.context.session_id, + "iteration": self.context.iteration, + }) # yield think 事件 tc_names = [tc["function"]["name"] for tc in tool_calls] @@ -1087,6 +1107,15 @@ class AgentRuntime: result = json.dumps({"error": hook_res.reason}, ensure_ascii=False) yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration} self.context.add_tool_result(tcid, tname, result) + if self.on_message: + self.on_message({ + "role": "tool", "content": result, + "session_id": self.context.session_id, + "iteration": self.context.iteration, + "tool_name": tname, + "tool_input": json.dumps(targs, ensure_ascii=False) if targs else None, + "tool_output": result, + }) continue if hook_res.modified_input: targs = hook_res.modified_input @@ -1119,11 +1148,29 @@ class AgentRuntime: result = f"[审批拒绝] 工具 {tname} 需要人工审批但被拒绝。" yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration} self.context.add_tool_result(tcid, tname, result) + if self.on_message: + self.on_message({ + "role": "tool", "content": result, + "session_id": self.context.session_id, + "iteration": self.context.iteration, + "tool_name": tname, + "tool_input": json.dumps(targs, ensure_ascii=False) if targs else None, + "tool_output": result, + }) continue elif decision == "skip": result = f"[审批跳过] 工具 {tname} 被跳过。" yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration} self.context.add_tool_result(tcid, tname, result) + if self.on_message: + self.on_message({ + "role": "tool", "content": result, + "session_id": self.context.session_id, + "iteration": self.context.iteration, + "tool_name": tname, + "tool_input": json.dumps(targs, ensure_ascii=False) if targs else None, + "tool_output": result, + }) continue # decision == "approved" → 继续执行 @@ -1154,6 +1201,15 @@ class AgentRuntime: )) self.context.add_tool_result(tcid, tname, result) + if self.on_message: + self.on_message({ + "role": "tool", "content": result, + "session_id": self.context.session_id, + "iteration": self.context.iteration, + "tool_name": tname, + "tool_input": json.dumps(targs, ensure_ascii=False) if targs else None, + "tool_output": result, + }) self.context.tool_calls_made += 1 # Hook: PostToolUse — 工具执行后处理 (流式) diff --git a/backend/app/api/agent_chat.py b/backend/app/api/agent_chat.py index f577ba6..19e5363 100644 --- a/backend/app/api/agent_chat.py +++ b/backend/app/api/agent_chat.py @@ -28,9 +28,11 @@ from app.agent_runtime import ( AgentBudgetConfig, AgentMemoryConfig, AgentStep, + AgentContext, AgentOrchestrator, OrchestratorAgentConfig, ) +from app.models.chat_message import ChatMessage from app.core.config import settings logger = logging.getLogger(__name__) @@ -68,6 +70,32 @@ def _make_llm_logger( return _log +def _make_message_saver( + db: Session, + agent_id: Optional[str] = None, + user_id: Optional[str] = None, +): + """创建消息持久化回调,将每条消息写入 chat_messages 表。""" + def _save(msg: dict): + try: + record = ChatMessage( + session_id=msg.get("session_id"), + agent_id=agent_id, + user_id=user_id, + role=msg.get("role", "user"), + content=msg.get("content"), + tool_name=msg.get("tool_name"), + tool_input=msg.get("tool_input"), + tool_output=msg.get("tool_output"), + iteration=msg.get("iteration", 0), + ) + db.add(record) + db.commit() + except Exception as e: + logger.warning("写入 ChatMessage 失败: %s", e) + return _save + + async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]: """将 run_stream 生成的 dict 事件格式化为 SSE 文本流。""" async for event in gen: @@ -99,6 +127,43 @@ class ChatResponse(BaseModel): token_usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 预算摘要") +class MessageItem(BaseModel): + """消息历史条目""" + id: str + session_id: str + agent_id: Optional[str] = None + user_id: Optional[str] = None + role: str + content: Optional[str] = None + tool_name: Optional[str] = None + tool_input: Optional[str] = None + tool_output: Optional[str] = None + iteration: int = 0 + created_at: Optional[str] = None + + +class MessageHistoryResponse(BaseModel): + """消息历史分页响应""" + messages: List[MessageItem] + has_more: bool + total: int + + +class SessionItem(BaseModel): + """会话列表条目""" + session_id: str + title: Optional[str] = None + last_message: Optional[str] = None + message_count: int = 0 + created_at: Optional[str] = None + updated_at: Optional[str] = None + + +class SessionListResponse(BaseModel): + """会话列表响应""" + sessions: List[SessionItem] + + class OrchestrateAgentItem(BaseModel): """编排中单个 Agent 的定义""" id: str @@ -270,7 +335,9 @@ async def chat_bare( if req.system_prompt_override: config.system_prompt = req.system_prompt_override on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id) - runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined) + on_message = _make_message_saver(db, agent_id=None, user_id=current_user.id) + context = AgentContext(session_id=req.session_id) + runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined) result = await runtime.run(req.message) # 流式美化:为 steps 生成累计摘要 @@ -333,7 +400,9 @@ async def chat_bare_stream( if req.system_prompt_override: config.system_prompt = req.system_prompt_override on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id) - runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined) + on_message = _make_message_saver(db, agent_id=None, user_id=current_user.id) + context = AgentContext(session_id=req.session_id) + runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined) return StreamingResponse( _sse_stream(runtime.run_stream(req.message)), media_type="text/event-stream", @@ -417,7 +486,9 @@ async def chat_with_agent( config.system_prompt = req.system_prompt_override on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id) - runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined) + on_message = _make_message_saver(db, agent_id=agent_id, user_id=current_user.id) + context = AgentContext(session_id=req.session_id) + runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined) result = await runtime.run(req.message) # 流式美化:为 steps 生成累计摘要 @@ -511,7 +582,9 @@ async def chat_with_agent_stream( config.system_prompt = req.system_prompt_override on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id) - runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined) + on_message = _make_message_saver(db, agent_id=agent_id, user_id=current_user.id) + context = AgentContext(session_id=req.session_id) + runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined) return StreamingResponse( _sse_stream(runtime.run_stream(req.message)), media_type="text/event-stream", @@ -523,6 +596,127 @@ async def chat_with_agent_stream( ) +@router.get("/{agent_id}/sessions", response_model=SessionListResponse) +async def list_agent_sessions( + agent_id: str, + limit: int = 50, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """获取 Agent 的会话列表,按最近活跃时间排序。""" + from sqlalchemy import func as sa_func, desc + + # 验证 agent 存在或有权限 + agent = db.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + raise HTTPException(status_code=404, detail="Agent 不存在") + + rows = ( + db.query( + ChatMessage.session_id, + sa_func.min(ChatMessage.created_at).label("created_at"), + sa_func.max(ChatMessage.created_at).label("updated_at"), + sa_func.count(ChatMessage.id).label("message_count"), + ) + .filter(ChatMessage.agent_id == agent_id) + .group_by(ChatMessage.session_id) + .order_by(desc("updated_at")) + .limit(limit) + .all() + ) + + sessions = [] + for row in rows: + # 取第一条 user 消息作为标题 + first_user_msg = ( + db.query(ChatMessage) + .filter( + ChatMessage.session_id == row.session_id, + ChatMessage.role == "user", + ) + .order_by(ChatMessage.created_at.asc()) + .first() + ) + # 取最后一条消息作为预览 + last_msg = ( + db.query(ChatMessage) + .filter(ChatMessage.session_id == row.session_id) + .order_by(ChatMessage.created_at.desc()) + .first() + ) + sessions.append(SessionItem( + session_id=row.session_id, + title=first_user_msg.content[:100] if first_user_msg and first_user_msg.content else None, + last_message=last_msg.content[:200] if last_msg and last_msg.content else None, + message_count=row.message_count, + created_at=row.created_at.isoformat() if row.created_at else None, + updated_at=row.updated_at.isoformat() if row.updated_at else None, + )) + + return SessionListResponse(sessions=sessions) + + +@router.get("/{agent_id}/sessions/{session_id}/messages", response_model=MessageHistoryResponse) +async def get_session_messages( + agent_id: str, + session_id: str, + before_id: Optional[str] = None, + limit: int = 50, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """获取会话的消息历史(分页),从旧到新排序。""" + # limit 限制 + limit = min(max(limit, 1), 200) + + base_q = db.query(ChatMessage).filter( + ChatMessage.agent_id == agent_id, + ChatMessage.session_id == session_id, + ) + + # 游标分页:before_id 之前的老消息 + if before_id: + cursor_msg = db.query(ChatMessage).filter(ChatMessage.id == before_id).first() + if cursor_msg and cursor_msg.created_at: + base_q = base_q.filter(ChatMessage.created_at < cursor_msg.created_at) + + # 取 N+1 条,判断 has_more(按时间降序取最新 N 条,再反转) + batch = ( + base_q + .order_by(ChatMessage.created_at.desc()) + .limit(limit + 1) + .all() + ) + + has_more = len(batch) > limit + if has_more: + batch = batch[:limit] + + # 反转为从旧到新 + batch.reverse() + + total = base_q.count() + + messages = [ + MessageItem( + id=m.id, + session_id=m.session_id, + agent_id=m.agent_id, + user_id=m.user_id, + role=m.role, + content=m.content, + tool_name=m.tool_name, + tool_input=m.tool_input, + tool_output=m.tool_output, + iteration=m.iteration or 0, + created_at=m.created_at.isoformat() if m.created_at else None, + ) + for m in batch + ] + + return MessageHistoryResponse(messages=messages, has_more=has_more, total=total) + + def _find_agent_node_config(nodes: list) -> Dict[str, Any]: """从工作流节点列表中查找第一个 agent 类型或 llm 类型的节点配置。""" if not nodes: diff --git a/backend/app/core/database.py b/backend/app/core/database.py index 495911c..3ee95bd 100644 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -76,4 +76,5 @@ def init_db(): import app.models.workspace import app.models.scene_contract import app.models.team + import app.models.chat_message Base.metadata.create_all(bind=engine) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index bedb203..7dbae02 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -33,5 +33,6 @@ from app.models.audit_log import AuditLog from app.models.workspace import Workspace, WorkspaceMembership from app.models.scene_contract import SceneContract from app.models.team import Team, TeamMember +from app.models.chat_message import ChatMessage -__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "GlobalKnowledge", "AgentRating", "AgentFavorite", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "AgentSchedule", "KnowledgeBase", "Document", "DocumentChunk", "Notification", "UserFeishuOpenId", "NodePlugin", "OrchestrationTemplate", "Goal", "Task", "AgentExecutionLog", "UserBehaviorLog", "KnowledgeEntry", "UserFingerprint", "ShadowComparison", "FeedbackRecord", "AuditLog", "Workspace", "WorkspaceMembership", "SceneContract", "Team", "TeamMember"] \ No newline at end of file +__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "GlobalKnowledge", "AgentRating", "AgentFavorite", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "AgentSchedule", "KnowledgeBase", "Document", "DocumentChunk", "Notification", "UserFeishuOpenId", "NodePlugin", "OrchestrationTemplate", "Goal", "Task", "AgentExecutionLog", "UserBehaviorLog", "KnowledgeEntry", "UserFingerprint", "ShadowComparison", "FeedbackRecord", "AuditLog", "Workspace", "WorkspaceMembership", "SceneContract", "Team", "TeamMember", "ChatMessage"] \ No newline at end of file diff --git a/backend/app/models/chat_message.py b/backend/app/models/chat_message.py new file mode 100644 index 0000000..d04f5d4 --- /dev/null +++ b/backend/app/models/chat_message.py @@ -0,0 +1,43 @@ +""" +Chat Message 持久化模型 — 保存 Agent 会话中的每条消息 +""" +from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Index, func +from sqlalchemy.dialects.mysql import CHAR +from app.core.database import Base +import uuid + + +class ChatMessage(Base): + """聊天消息表""" + __tablename__ = "chat_messages" + + id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="消息ID") + session_id = Column(CHAR(36), nullable=False, comment="会话ID") + agent_id = Column(CHAR(36), ForeignKey("agents.id", ondelete="SET NULL"), nullable=True, comment="智能体ID") + user_id = Column(CHAR(36), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, comment="用户ID") + role = Column(String(20), nullable=False, comment="角色: user/assistant/tool/system") + content = Column(Text, comment="消息内容") + tool_name = Column(String(100), nullable=True, comment="工具名称(仅tool消息)") + tool_input = Column(Text, nullable=True, comment="工具输入参数JSON(仅tool消息)") + tool_output = Column(Text, nullable=True, comment="工具输出结果(仅tool消息)") + iteration = Column(Integer, default=0, comment="Agent 迭代序号") + created_at = Column(DateTime, default=func.now(), comment="创建时间") + + __table_args__ = ( + Index("ix_chat_messages_session_created", "session_id", "created_at"), + ) + + def to_dict(self): + return { + "id": self.id, + "session_id": self.session_id, + "agent_id": self.agent_id, + "user_id": self.user_id, + "role": self.role, + "content": self.content, + "tool_name": self.tool_name, + "tool_input": self.tool_input, + "tool_output": self.tool_output, + "iteration": self.iteration, + "created_at": self.created_at.isoformat() if self.created_at else None, + }