feat: persistent chat message storage + Android pull-to-load history
Backend:
- Add ChatMessage model + Alembic migration 024
- Add on_message callback to AgentRuntime for persisting messages during SSE streaming
- Plumb session_id from ChatRequest to AgentContext in all 4 chat endpoints
- Add GET /agent-chat/{id}/sessions and /sessions/{sid}/messages with cursor pagination
Android:
- Add DTOs/ApiService/MessageDao for server-side chat history
- ChatRepository: fetchOlderMessages (API + Room cache), offline fallback
- ChatViewModel: loadMoreHistory with isLoadingMore/hasMoreMessages state
- ChatScreen: scroll-to-top detection + top loading indicator
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -23,4 +23,10 @@ interface MessageDao {
|
||||
|
||||
@Query("DELETE FROM messages")
|
||||
suspend fun deleteAll()
|
||||
|
||||
@Query("SELECT * FROM messages WHERE conversationId = :conversationId AND createdAt < :beforeTimestamp ORDER BY createdAt DESC LIMIT :limit")
|
||||
suspend fun getMessagesBeforeTimestamp(conversationId: String, beforeTimestamp: Long, limit: Int): List<MessageEntity>
|
||||
|
||||
@Query("SELECT COUNT(*) FROM messages WHERE conversationId = :conversationId")
|
||||
suspend fun getMessageCount(conversationId: String): Int
|
||||
}
|
||||
|
||||
@@ -97,4 +97,19 @@ interface ApiService {
|
||||
// ─── Feedback ───
|
||||
@POST("api/v1/feedback")
|
||||
suspend fun submitFeedback(@Body request: FeedbackRequest): FeedbackResponse
|
||||
|
||||
// ─── Chat History ───
|
||||
@GET("api/v1/agent-chat/{agentId}/sessions/{sessionId}/messages")
|
||||
suspend fun getSessionMessages(
|
||||
@Path("agentId") agentId: String,
|
||||
@Path("sessionId") sessionId: String,
|
||||
@Query("before_id") beforeId: String? = null,
|
||||
@Query("limit") limit: Int = 50
|
||||
): MessageHistoryResponse
|
||||
|
||||
@GET("api/v1/agent-chat/{agentId}/sessions")
|
||||
suspend fun getAgentSessions(
|
||||
@Path("agentId") agentId: String,
|
||||
@Query("limit") limit: Int = 50
|
||||
): SessionListResponse
|
||||
}
|
||||
|
||||
@@ -275,6 +275,45 @@ data class RegisterResponse(
|
||||
val message: String? = null
|
||||
)
|
||||
|
||||
// ─────────── Feedback ───────────
|
||||
|
||||
|
||||
// ─────────── Chat History ───────────
|
||||
|
||||
data class MessageItemDto(
|
||||
val id: String,
|
||||
@SerializedName("session_id") val sessionId: String,
|
||||
@SerializedName("agent_id") val agentId: String? = null,
|
||||
@SerializedName("user_id") val userId: String? = null,
|
||||
val role: String,
|
||||
val content: String? = null,
|
||||
@SerializedName("tool_name") val toolName: String? = null,
|
||||
@SerializedName("tool_input") val toolInput: String? = null,
|
||||
@SerializedName("tool_output") val toolOutput: String? = null,
|
||||
val iteration: Int = 0,
|
||||
@SerializedName("created_at") val createdAt: String? = null
|
||||
)
|
||||
|
||||
data class MessageHistoryResponse(
|
||||
val messages: List<MessageItemDto>,
|
||||
@SerializedName("has_more") val hasMore: Boolean,
|
||||
val total: Int
|
||||
)
|
||||
|
||||
data class SessionItemDto(
|
||||
@SerializedName("session_id") val sessionId: String,
|
||||
val title: String? = null,
|
||||
@SerializedName("last_message") val lastMessage: String? = null,
|
||||
@SerializedName("message_count") val messageCount: Int = 0,
|
||||
@SerializedName("created_at") val createdAt: String? = null,
|
||||
@SerializedName("updated_at") val updatedAt: String? = null
|
||||
)
|
||||
|
||||
data class SessionListResponse(
|
||||
val sessions: List<SessionItemDto>
|
||||
)
|
||||
|
||||
|
||||
// ─────────── Feedback ───────────
|
||||
|
||||
data class FeedbackRequest(
|
||||
|
||||
@@ -8,6 +8,10 @@ import com.tiangong.aiagent.data.remote.ApiService
|
||||
import com.tiangong.aiagent.data.remote.SseClient
|
||||
import com.tiangong.aiagent.data.remote.dto.ChatRequest
|
||||
import com.tiangong.aiagent.data.remote.dto.ChatResponse
|
||||
import com.tiangong.aiagent.data.remote.dto.MessageHistoryResponse
|
||||
import com.tiangong.aiagent.data.remote.dto.MessageItemDto
|
||||
import com.tiangong.aiagent.data.remote.dto.SessionListResponse
|
||||
import com.tiangong.aiagent.data.remote.dto.SessionItemDto
|
||||
import com.tiangong.aiagent.data.remote.dto.TtsRequest
|
||||
import com.tiangong.aiagent.domain.model.Conversation
|
||||
import com.tiangong.aiagent.domain.model.Message
|
||||
@@ -257,4 +261,136 @@ class ChatRepository @Inject constructor(
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Chat History (v1.3.0): Server-side pagination + Room cache ───
|
||||
|
||||
/** Fetch older messages from the server for cursor-based pagination. */
|
||||
suspend fun fetchOlderMessages(
|
||||
agentId: String,
|
||||
sessionId: String,
|
||||
beforeId: String? = null,
|
||||
limit: Int = 50
|
||||
): Result<Pair<List<Message>, Boolean>> {
|
||||
return try {
|
||||
val response: MessageHistoryResponse = apiService.getSessionMessages(
|
||||
agentId = agentId,
|
||||
sessionId = sessionId,
|
||||
beforeId = beforeId,
|
||||
limit = limit
|
||||
)
|
||||
// Cache fetched messages to local Room
|
||||
cacheMessagesFromApi(response.messages, sessionId, agentId)
|
||||
val messages = response.messages.map { it.toDomainMessage() }
|
||||
Result.success(Pair(messages, response.hasMore))
|
||||
} catch (e: Exception) {
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
/** Fetch agent's session list from server. */
|
||||
suspend fun fetchAgentSessions(agentId: String, limit: Int = 50): Result<List<SessionItemDto>> {
|
||||
return try {
|
||||
val response: SessionListResponse = apiService.getAgentSessions(agentId, limit)
|
||||
Result.success(response.sessions)
|
||||
} catch (e: Exception) {
|
||||
Result.failure(e)
|
||||
}
|
||||
}
|
||||
|
||||
/** Get older messages from local Room (offline fallback). */
|
||||
suspend fun getOlderMessagesFromRoom(
|
||||
conversationId: String,
|
||||
beforeTimestamp: Long,
|
||||
limit: Int = 50
|
||||
): List<Message> {
|
||||
val entities = database.messageDao().getMessagesBeforeTimestamp(
|
||||
conversationId = conversationId,
|
||||
beforeTimestamp = beforeTimestamp,
|
||||
limit = limit
|
||||
)
|
||||
return entities.map { it.toDomainMessage() }
|
||||
}
|
||||
|
||||
/** Check if there are more messages available on the server for a session. */
|
||||
suspend fun hasMoreServerMessages(agentId: String, sessionId: String): Boolean {
|
||||
val localCount = database.messageDao().getMessageCount(sessionId)
|
||||
return try {
|
||||
val response = apiService.getSessionMessages(
|
||||
agentId = agentId, sessionId = sessionId, beforeId = null, limit = 1
|
||||
)
|
||||
response.total > localCount
|
||||
} catch (e: Exception) {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Helpers ───
|
||||
|
||||
private suspend fun cacheMessagesFromApi(dtos: List<MessageItemDto>, sessionId: String, agentId: String?) {
|
||||
val dao = database.messageDao()
|
||||
for (dto in dtos) {
|
||||
dao.insert(
|
||||
MessageEntity(
|
||||
id = dto.id,
|
||||
conversationId = sessionId,
|
||||
agentId = dto.agentId ?: agentId,
|
||||
role = dto.role,
|
||||
content = dto.content ?: "",
|
||||
toolName = dto.toolName,
|
||||
toolInput = dto.toolInput,
|
||||
toolOutput = dto.toolOutput,
|
||||
tokenUsageJson = null,
|
||||
createdAt = dto.createdAt?.let { parseIso8601ToEpoch(it) } ?: System.currentTimeMillis()
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun MessageItemDto.toDomainMessage(): Message {
|
||||
return Message(
|
||||
id = id,
|
||||
conversationId = sessionId,
|
||||
agentId = agentId,
|
||||
role = try {
|
||||
Message.Role.valueOf(role.uppercase())
|
||||
} catch (e: Exception) {
|
||||
Message.Role.SYSTEM
|
||||
},
|
||||
content = content ?: "",
|
||||
toolName = toolName,
|
||||
toolInput = toolInput,
|
||||
toolOutput = toolOutput,
|
||||
createdAt = createdAt?.let { parseIso8601ToEpoch(it) } ?: System.currentTimeMillis()
|
||||
)
|
||||
}
|
||||
|
||||
private fun MessageEntity.toDomainMessage(): Message {
|
||||
return Message(
|
||||
id = id,
|
||||
conversationId = conversationId,
|
||||
agentId = agentId,
|
||||
role = try {
|
||||
Message.Role.valueOf(role.uppercase())
|
||||
} catch (e: Exception) {
|
||||
Message.Role.SYSTEM
|
||||
},
|
||||
content = content,
|
||||
toolName = toolName,
|
||||
toolInput = toolInput,
|
||||
toolOutput = toolOutput,
|
||||
createdAt = createdAt
|
||||
)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun parseIso8601ToEpoch(isoString: String): Long {
|
||||
return try {
|
||||
val sdf = java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", java.util.Locale.US)
|
||||
sdf.timeZone = java.util.TimeZone.getTimeZone("UTC")
|
||||
sdf.parse(isoString.substringBefore('.'))?.time ?: System.currentTimeMillis()
|
||||
} catch (e: Exception) {
|
||||
System.currentTimeMillis()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,13 +74,27 @@ fun ChatScreen(
|
||||
onDispose { lifecycleOwner.lifecycle.removeObserver(observer) }
|
||||
}
|
||||
|
||||
// Auto-scroll when new messages arrive
|
||||
// Auto-scroll when new messages arrive (only when not loading more history)
|
||||
LaunchedEffect(uiState.messages.size) {
|
||||
if (uiState.messages.isNotEmpty()) {
|
||||
if (uiState.messages.isNotEmpty() && !uiState.isLoadingMore) {
|
||||
listState.animateScrollToItem(uiState.messages.size - 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Scroll-to-top detection for loading more history
|
||||
val isAtTop by remember {
|
||||
derivedStateOf {
|
||||
listState.firstVisibleItemIndex == 0 &&
|
||||
listState.firstVisibleItemScrollOffset == 0
|
||||
}
|
||||
}
|
||||
|
||||
LaunchedEffect(isAtTop) {
|
||||
if (isAtTop && uiState.hasMoreMessages && !uiState.isLoadingMore) {
|
||||
viewModel.loadMoreHistory()
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-fill input when editing a message
|
||||
LaunchedEffect(uiState.editingMessageId) {
|
||||
uiState.editingMessageContent?.let { inputText = it }
|
||||
@@ -452,6 +466,28 @@ fun ChatScreen(
|
||||
modifier = Modifier.fillMaxSize().weight(1f),
|
||||
contentPadding = PaddingValues(vertical = 8.dp)
|
||||
) {
|
||||
// Loading indicator at top (pull-to-load more history)
|
||||
if (uiState.isLoadingMore) {
|
||||
item(key = "loading_more") {
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth().padding(8.dp),
|
||||
horizontalArrangement = Arrangement.Center,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier.size(20.dp),
|
||||
strokeWidth = 2.dp
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text(
|
||||
text = "加载更多...",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skeleton loading on first load
|
||||
if (!firstLoaded && uiState.messages.isEmpty() && !uiState.isStreaming) {
|
||||
item { SkeletonChat() }
|
||||
|
||||
@@ -69,7 +69,11 @@ data class ChatUiState(
|
||||
val pendingQueueCount: Int = 0,
|
||||
|
||||
// Think trace entries (v1.1.0)
|
||||
val thinkTraces: List<ThinkTraceEntry> = emptyList()
|
||||
val thinkTraces: List<ThinkTraceEntry> = emptyList(),
|
||||
|
||||
// Pull-to-load history (v1.3.0)
|
||||
val isLoadingMore: Boolean = false,
|
||||
val hasMoreMessages: Boolean = true
|
||||
)
|
||||
|
||||
data class ThinkTraceEntry(
|
||||
@@ -236,13 +240,34 @@ class ChatViewModel @Inject constructor(
|
||||
historyFlowJob?.cancel()
|
||||
historyFlowJob = viewModelScope.launch {
|
||||
val conversations = chatRepository.getConversationsByAgent(agentId)
|
||||
val latestConversation = conversations.firstOrNull() ?: return@launch
|
||||
chatRepository.getMessages(latestConversation.sessionId).collect { messages ->
|
||||
if (messages.isNotEmpty() && _uiState.value.currentAgent?.id == agentId) {
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = messages.map { it.toUiMessage() },
|
||||
sessionId = latestConversation.sessionId
|
||||
val latestConversation = conversations.firstOrNull()
|
||||
if (latestConversation != null) {
|
||||
chatRepository.getMessages(latestConversation.sessionId).collect { messages ->
|
||||
if (messages.isNotEmpty() && _uiState.value.currentAgent?.id == agentId) {
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = messages.map { it.toUiMessage() },
|
||||
sessionId = latestConversation.sessionId
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No local conversations — try fetching from server
|
||||
val result = chatRepository.fetchAgentSessions(agentId, limit = 1)
|
||||
if (result.isSuccess) {
|
||||
val sessions = result.getOrThrow()
|
||||
val latestSession = sessions.firstOrNull() ?: return@launch
|
||||
val msgResult = chatRepository.fetchOlderMessages(
|
||||
agentId = agentId, sessionId = latestSession.sessionId, limit = 50
|
||||
)
|
||||
if (msgResult.isSuccess) {
|
||||
val (msgs, hasMore) = msgResult.getOrThrow()
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = msgs.map { it.toUiMessage() },
|
||||
sessionId = latestSession.sessionId,
|
||||
hasMoreMessages = hasMore
|
||||
)
|
||||
tokenDataStore.saveLastSessionId(latestSession.sessionId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -271,16 +296,106 @@ class ChatViewModel @Inject constructor(
|
||||
sseJob?.cancel()
|
||||
viewModelScope.launch {
|
||||
tokenDataStore.saveLastSessionId(sessionId)
|
||||
chatRepository.getMessages(sessionId).collect { messages ->
|
||||
if (messages.isNotEmpty()) {
|
||||
// First check Room
|
||||
val localMsgs = chatRepository.getMessages(sessionId).first()
|
||||
if (localMsgs.isNotEmpty()) {
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = localMsgs.map { it.toUiMessage() },
|
||||
sessionId = sessionId,
|
||||
streamingContent = "",
|
||||
isLoading = false,
|
||||
isStreaming = false,
|
||||
error = null,
|
||||
reconnectionState = ReconnectionState.Idle
|
||||
)
|
||||
// Continue observing local changes
|
||||
historyFlowJob?.cancel()
|
||||
historyFlowJob = viewModelScope.launch {
|
||||
chatRepository.getMessages(sessionId).collect { messages ->
|
||||
if (messages.isNotEmpty()) {
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = messages.map { it.toUiMessage() }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Try server
|
||||
val agentId = _uiState.value.currentAgent?.id ?: return@launch
|
||||
val result = chatRepository.fetchOlderMessages(
|
||||
agentId = agentId, sessionId = sessionId, limit = 50
|
||||
)
|
||||
if (result.isSuccess) {
|
||||
val (msgs, hasMore) = result.getOrThrow()
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = messages.map { it.toUiMessage() },
|
||||
messages = msgs.map { it.toUiMessage() },
|
||||
sessionId = sessionId,
|
||||
streamingContent = "",
|
||||
isLoading = false,
|
||||
isStreaming = false,
|
||||
error = null,
|
||||
reconnectionState = ReconnectionState.Idle
|
||||
reconnectionState = ReconnectionState.Idle,
|
||||
hasMoreMessages = hasMore
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun loadMoreHistory() {
|
||||
val state = _uiState.value
|
||||
if (state.isLoadingMore || !state.hasMoreMessages) return
|
||||
val agentId = state.currentAgent?.id ?: return
|
||||
val sessionId = state.sessionId ?: return
|
||||
val oldestMsg = state.messages.firstOrNull() ?: return
|
||||
|
||||
viewModelScope.launch {
|
||||
_uiState.value = _uiState.value.copy(isLoadingMore = true)
|
||||
|
||||
// Capture current scroll anchor (first visible item id + offset)
|
||||
val anchorId = oldestMsg.id
|
||||
|
||||
val result = chatRepository.fetchOlderMessages(
|
||||
agentId = agentId,
|
||||
sessionId = sessionId,
|
||||
beforeId = anchorId,
|
||||
limit = 50
|
||||
)
|
||||
|
||||
if (result.isSuccess) {
|
||||
val (olderMessages, hasMore) = result.getOrThrow()
|
||||
val current = _uiState.value.messages.toMutableList()
|
||||
// Prepend older messages (avoid duplicates by id)
|
||||
val existingIds = current.map { it.id }.toSet()
|
||||
val newMsgs = olderMessages.filter { it.id !in existingIds }.map { it.toUiMessage() }
|
||||
current.addAll(0, newMsgs)
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = current,
|
||||
hasMoreMessages = hasMore,
|
||||
isLoadingMore = false
|
||||
)
|
||||
} else {
|
||||
// API failed — try Room fallback
|
||||
val oldestTimestamp = oldestMsg.createdAt
|
||||
val roomMsgs = chatRepository.getOlderMessagesFromRoom(
|
||||
conversationId = sessionId,
|
||||
beforeTimestamp = oldestTimestamp,
|
||||
limit = 50
|
||||
)
|
||||
if (roomMsgs.isNotEmpty()) {
|
||||
val current = _uiState.value.messages.toMutableList()
|
||||
val existingIds = current.map { it.id }.toSet()
|
||||
val newMsgs = roomMsgs.filter { it.id !in existingIds }.map { it.toUiMessage() }
|
||||
current.addAll(0, newMsgs)
|
||||
_uiState.value = _uiState.value.copy(
|
||||
messages = current,
|
||||
hasMoreMessages = roomMsgs.size >= 50,
|
||||
isLoadingMore = false
|
||||
)
|
||||
} else {
|
||||
_uiState.value = _uiState.value.copy(
|
||||
hasMoreMessages = false,
|
||||
isLoadingMore = false
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""024_add_chat_messages
|
||||
|
||||
Revision ID: 104e05fc9cf2
|
||||
Revises: 15623091001e
|
||||
Create Date: 2026-06-29 23:40:32.303965
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '104e05fc9cf2'
|
||||
down_revision = '15623091001e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table('chat_messages',
|
||||
sa.Column('id', mysql.CHAR(length=36), nullable=False, comment='消息ID'),
|
||||
sa.Column('session_id', mysql.CHAR(length=36), nullable=False, comment='会话ID'),
|
||||
sa.Column('agent_id', mysql.CHAR(length=36), nullable=True, comment='智能体ID'),
|
||||
sa.Column('user_id', mysql.CHAR(length=36), nullable=True, comment='用户ID'),
|
||||
sa.Column('role', sa.String(length=20), nullable=False, comment='角色: user/assistant/tool/system'),
|
||||
sa.Column('content', sa.Text(), nullable=True, comment='消息内容'),
|
||||
sa.Column('tool_name', sa.String(length=100), nullable=True, comment='工具名称'),
|
||||
sa.Column('tool_input', sa.Text(), nullable=True, comment='工具输入参数JSON'),
|
||||
sa.Column('tool_output', sa.Text(), nullable=True, comment='工具输出结果'),
|
||||
sa.Column('iteration', sa.Integer(), nullable=True, comment='Agent迭代序号'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'),
|
||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='SET NULL'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
mysql_collate='utf8mb4_unicode_ci',
|
||||
mysql_default_charset='utf8mb4',
|
||||
mysql_engine='InnoDB'
|
||||
)
|
||||
op.create_index('ix_chat_messages_session_created', 'chat_messages', ['session_id', 'created_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index('ix_chat_messages_session_created', table_name='chat_messages')
|
||||
op.drop_table('chat_messages')
|
||||
@@ -105,6 +105,7 @@ class AgentRuntime:
|
||||
execution_logger: Optional[Any] = None,
|
||||
on_tool_executed: Optional[Callable[[str], Any]] = None,
|
||||
on_llm_call: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
||||
on_message: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
||||
hook_manager: Optional[HookManager] = None,
|
||||
streamlined: bool = False,
|
||||
):
|
||||
@@ -141,6 +142,7 @@ class AgentRuntime:
|
||||
self.execution_logger = execution_logger
|
||||
self.on_tool_executed = on_tool_executed
|
||||
self.on_llm_call = on_llm_call
|
||||
self.on_message = on_message
|
||||
self._memory_context_loaded = False
|
||||
self._llm_invocations = 0
|
||||
# 自主学习作用域:bare 聊天用 "bare",Agent 用 "agent"
|
||||
@@ -808,6 +810,12 @@ class AgentRuntime:
|
||||
|
||||
# 2. 追加用户消息
|
||||
self.context.add_user_message(user_input)
|
||||
if self.on_message:
|
||||
self.on_message({
|
||||
"role": "user", "content": user_input,
|
||||
"session_id": self.context.session_id,
|
||||
"iteration": 0,
|
||||
})
|
||||
|
||||
# 2.5 计划模式 (P2) — 流式生成执行计划
|
||||
plan: Optional[Plan] = None
|
||||
@@ -969,6 +977,12 @@ class AgentRuntime:
|
||||
# LLM 直接返回文本 → 结束
|
||||
self.context.add_assistant_message(content)
|
||||
final_text = content or "(模型未返回有效内容)"
|
||||
if self.on_message:
|
||||
self.on_message({
|
||||
"role": "assistant", "content": final_text,
|
||||
"session_id": self.context.session_id,
|
||||
"iteration": self.context.iteration,
|
||||
})
|
||||
review_score = 0.0
|
||||
|
||||
# 输出质量自检(默认关闭)
|
||||
@@ -1027,6 +1041,12 @@ class AgentRuntime:
|
||||
|
||||
# 有工具调用 → 先记录 assistant 消息
|
||||
self.context.add_assistant_message(content or "", tool_calls, reasoning)
|
||||
if self.on_message:
|
||||
self.on_message({
|
||||
"role": "assistant", "content": content or "",
|
||||
"session_id": self.context.session_id,
|
||||
"iteration": self.context.iteration,
|
||||
})
|
||||
|
||||
# yield think 事件
|
||||
tc_names = [tc["function"]["name"] for tc in tool_calls]
|
||||
@@ -1087,6 +1107,15 @@ class AgentRuntime:
|
||||
result = json.dumps({"error": hook_res.reason}, ensure_ascii=False)
|
||||
yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration}
|
||||
self.context.add_tool_result(tcid, tname, result)
|
||||
if self.on_message:
|
||||
self.on_message({
|
||||
"role": "tool", "content": result,
|
||||
"session_id": self.context.session_id,
|
||||
"iteration": self.context.iteration,
|
||||
"tool_name": tname,
|
||||
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
|
||||
"tool_output": result,
|
||||
})
|
||||
continue
|
||||
if hook_res.modified_input:
|
||||
targs = hook_res.modified_input
|
||||
@@ -1119,11 +1148,29 @@ class AgentRuntime:
|
||||
result = f"[审批拒绝] 工具 {tname} 需要人工审批但被拒绝。"
|
||||
yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration}
|
||||
self.context.add_tool_result(tcid, tname, result)
|
||||
if self.on_message:
|
||||
self.on_message({
|
||||
"role": "tool", "content": result,
|
||||
"session_id": self.context.session_id,
|
||||
"iteration": self.context.iteration,
|
||||
"tool_name": tname,
|
||||
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
|
||||
"tool_output": result,
|
||||
})
|
||||
continue
|
||||
elif decision == "skip":
|
||||
result = f"[审批跳过] 工具 {tname} 被跳过。"
|
||||
yield {"type": "tool_result", "name": tname, "result": result, "iteration": self.context.iteration}
|
||||
self.context.add_tool_result(tcid, tname, result)
|
||||
if self.on_message:
|
||||
self.on_message({
|
||||
"role": "tool", "content": result,
|
||||
"session_id": self.context.session_id,
|
||||
"iteration": self.context.iteration,
|
||||
"tool_name": tname,
|
||||
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
|
||||
"tool_output": result,
|
||||
})
|
||||
continue
|
||||
# decision == "approved" → 继续执行
|
||||
|
||||
@@ -1154,6 +1201,15 @@ class AgentRuntime:
|
||||
))
|
||||
|
||||
self.context.add_tool_result(tcid, tname, result)
|
||||
if self.on_message:
|
||||
self.on_message({
|
||||
"role": "tool", "content": result,
|
||||
"session_id": self.context.session_id,
|
||||
"iteration": self.context.iteration,
|
||||
"tool_name": tname,
|
||||
"tool_input": json.dumps(targs, ensure_ascii=False) if targs else None,
|
||||
"tool_output": result,
|
||||
})
|
||||
self.context.tool_calls_made += 1
|
||||
|
||||
# Hook: PostToolUse — 工具执行后处理 (流式)
|
||||
|
||||
@@ -28,9 +28,11 @@ from app.agent_runtime import (
|
||||
AgentBudgetConfig,
|
||||
AgentMemoryConfig,
|
||||
AgentStep,
|
||||
AgentContext,
|
||||
AgentOrchestrator,
|
||||
OrchestratorAgentConfig,
|
||||
)
|
||||
from app.models.chat_message import ChatMessage
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -68,6 +70,32 @@ def _make_llm_logger(
|
||||
return _log
|
||||
|
||||
|
||||
def _make_message_saver(
|
||||
db: Session,
|
||||
agent_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
):
|
||||
"""创建消息持久化回调,将每条消息写入 chat_messages 表。"""
|
||||
def _save(msg: dict):
|
||||
try:
|
||||
record = ChatMessage(
|
||||
session_id=msg.get("session_id"),
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
role=msg.get("role", "user"),
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_input=msg.get("tool_input"),
|
||||
tool_output=msg.get("tool_output"),
|
||||
iteration=msg.get("iteration", 0),
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.warning("写入 ChatMessage 失败: %s", e)
|
||||
return _save
|
||||
|
||||
|
||||
async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
|
||||
"""将 run_stream 生成的 dict 事件格式化为 SSE 文本流。"""
|
||||
async for event in gen:
|
||||
@@ -99,6 +127,43 @@ class ChatResponse(BaseModel):
|
||||
token_usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 预算摘要")
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
|
||||
"""消息历史条目"""
|
||||
id: str
|
||||
session_id: str
|
||||
agent_id: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_name: Optional[str] = None
|
||||
tool_input: Optional[str] = None
|
||||
tool_output: Optional[str] = None
|
||||
iteration: int = 0
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class MessageHistoryResponse(BaseModel):
|
||||
"""消息历史分页响应"""
|
||||
messages: List[MessageItem]
|
||||
has_more: bool
|
||||
total: int
|
||||
|
||||
|
||||
class SessionItem(BaseModel):
|
||||
"""会话列表条目"""
|
||||
session_id: str
|
||||
title: Optional[str] = None
|
||||
last_message: Optional[str] = None
|
||||
message_count: int = 0
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""会话列表响应"""
|
||||
sessions: List[SessionItem]
|
||||
|
||||
|
||||
class OrchestrateAgentItem(BaseModel):
|
||||
"""编排中单个 Agent 的定义"""
|
||||
id: str
|
||||
@@ -270,7 +335,9 @@ async def chat_bare(
|
||||
if req.system_prompt_override:
|
||||
config.system_prompt = req.system_prompt_override
|
||||
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
on_message = _make_message_saver(db, agent_id=None, user_id=current_user.id)
|
||||
context = AgentContext(session_id=req.session_id)
|
||||
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
|
||||
result = await runtime.run(req.message)
|
||||
|
||||
# 流式美化:为 steps 生成累计摘要
|
||||
@@ -333,7 +400,9 @@ async def chat_bare_stream(
|
||||
if req.system_prompt_override:
|
||||
config.system_prompt = req.system_prompt_override
|
||||
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
on_message = _make_message_saver(db, agent_id=None, user_id=current_user.id)
|
||||
context = AgentContext(session_id=req.session_id)
|
||||
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
@@ -417,7 +486,9 @@ async def chat_with_agent(
|
||||
config.system_prompt = req.system_prompt_override
|
||||
|
||||
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
on_message = _make_message_saver(db, agent_id=agent_id, user_id=current_user.id)
|
||||
context = AgentContext(session_id=req.session_id)
|
||||
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
|
||||
result = await runtime.run(req.message)
|
||||
|
||||
# 流式美化:为 steps 生成累计摘要
|
||||
@@ -511,7 +582,9 @@ async def chat_with_agent_stream(
|
||||
config.system_prompt = req.system_prompt_override
|
||||
|
||||
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
on_message = _make_message_saver(db, agent_id=agent_id, user_id=current_user.id)
|
||||
context = AgentContext(session_id=req.session_id)
|
||||
runtime = AgentRuntime(config=config, context=context, on_llm_call=on_llm_call, on_message=on_message, streamlined=req.streamlined)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
@@ -523,6 +596,127 @@ async def chat_with_agent_stream(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/sessions", response_model=SessionListResponse)
|
||||
async def list_agent_sessions(
|
||||
agent_id: str,
|
||||
limit: int = 50,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取 Agent 的会话列表,按最近活跃时间排序。"""
|
||||
from sqlalchemy import func as sa_func, desc
|
||||
|
||||
# 验证 agent 存在或有权限
|
||||
agent = db.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Agent 不存在")
|
||||
|
||||
rows = (
|
||||
db.query(
|
||||
ChatMessage.session_id,
|
||||
sa_func.min(ChatMessage.created_at).label("created_at"),
|
||||
sa_func.max(ChatMessage.created_at).label("updated_at"),
|
||||
sa_func.count(ChatMessage.id).label("message_count"),
|
||||
)
|
||||
.filter(ChatMessage.agent_id == agent_id)
|
||||
.group_by(ChatMessage.session_id)
|
||||
.order_by(desc("updated_at"))
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
sessions = []
|
||||
for row in rows:
|
||||
# 取第一条 user 消息作为标题
|
||||
first_user_msg = (
|
||||
db.query(ChatMessage)
|
||||
.filter(
|
||||
ChatMessage.session_id == row.session_id,
|
||||
ChatMessage.role == "user",
|
||||
)
|
||||
.order_by(ChatMessage.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
# 取最后一条消息作为预览
|
||||
last_msg = (
|
||||
db.query(ChatMessage)
|
||||
.filter(ChatMessage.session_id == row.session_id)
|
||||
.order_by(ChatMessage.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
sessions.append(SessionItem(
|
||||
session_id=row.session_id,
|
||||
title=first_user_msg.content[:100] if first_user_msg and first_user_msg.content else None,
|
||||
last_message=last_msg.content[:200] if last_msg and last_msg.content else None,
|
||||
message_count=row.message_count,
|
||||
created_at=row.created_at.isoformat() if row.created_at else None,
|
||||
updated_at=row.updated_at.isoformat() if row.updated_at else None,
|
||||
))
|
||||
|
||||
return SessionListResponse(sessions=sessions)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/sessions/{session_id}/messages", response_model=MessageHistoryResponse)
|
||||
async def get_session_messages(
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
before_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取会话的消息历史(分页),从旧到新排序。"""
|
||||
# limit 限制
|
||||
limit = min(max(limit, 1), 200)
|
||||
|
||||
base_q = db.query(ChatMessage).filter(
|
||||
ChatMessage.agent_id == agent_id,
|
||||
ChatMessage.session_id == session_id,
|
||||
)
|
||||
|
||||
# 游标分页:before_id 之前的老消息
|
||||
if before_id:
|
||||
cursor_msg = db.query(ChatMessage).filter(ChatMessage.id == before_id).first()
|
||||
if cursor_msg and cursor_msg.created_at:
|
||||
base_q = base_q.filter(ChatMessage.created_at < cursor_msg.created_at)
|
||||
|
||||
# 取 N+1 条,判断 has_more(按时间降序取最新 N 条,再反转)
|
||||
batch = (
|
||||
base_q
|
||||
.order_by(ChatMessage.created_at.desc())
|
||||
.limit(limit + 1)
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = len(batch) > limit
|
||||
if has_more:
|
||||
batch = batch[:limit]
|
||||
|
||||
# 反转为从旧到新
|
||||
batch.reverse()
|
||||
|
||||
total = base_q.count()
|
||||
|
||||
messages = [
|
||||
MessageItem(
|
||||
id=m.id,
|
||||
session_id=m.session_id,
|
||||
agent_id=m.agent_id,
|
||||
user_id=m.user_id,
|
||||
role=m.role,
|
||||
content=m.content,
|
||||
tool_name=m.tool_name,
|
||||
tool_input=m.tool_input,
|
||||
tool_output=m.tool_output,
|
||||
iteration=m.iteration or 0,
|
||||
created_at=m.created_at.isoformat() if m.created_at else None,
|
||||
)
|
||||
for m in batch
|
||||
]
|
||||
|
||||
return MessageHistoryResponse(messages=messages, has_more=has_more, total=total)
|
||||
|
||||
|
||||
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
|
||||
"""从工作流节点列表中查找第一个 agent 类型或 llm 类型的节点配置。"""
|
||||
if not nodes:
|
||||
|
||||
@@ -76,4 +76,5 @@ def init_db():
|
||||
import app.models.workspace
|
||||
import app.models.scene_contract
|
||||
import app.models.team
|
||||
import app.models.chat_message
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
@@ -33,5 +33,6 @@ from app.models.audit_log import AuditLog
|
||||
from app.models.workspace import Workspace, WorkspaceMembership
|
||||
from app.models.scene_contract import SceneContract
|
||||
from app.models.team import Team, TeamMember
|
||||
from app.models.chat_message import ChatMessage
|
||||
|
||||
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "GlobalKnowledge", "AgentRating", "AgentFavorite", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "AgentSchedule", "KnowledgeBase", "Document", "DocumentChunk", "Notification", "UserFeishuOpenId", "NodePlugin", "OrchestrationTemplate", "Goal", "Task", "AgentExecutionLog", "UserBehaviorLog", "KnowledgeEntry", "UserFingerprint", "ShadowComparison", "FeedbackRecord", "AuditLog", "Workspace", "WorkspaceMembership", "SceneContract", "Team", "TeamMember"]
|
||||
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "GlobalKnowledge", "AgentRating", "AgentFavorite", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "AgentSchedule", "KnowledgeBase", "Document", "DocumentChunk", "Notification", "UserFeishuOpenId", "NodePlugin", "OrchestrationTemplate", "Goal", "Task", "AgentExecutionLog", "UserBehaviorLog", "KnowledgeEntry", "UserFingerprint", "ShadowComparison", "FeedbackRecord", "AuditLog", "Workspace", "WorkspaceMembership", "SceneContract", "Team", "TeamMember", "ChatMessage"]
|
||||
43
backend/app/models/chat_message.py
Normal file
43
backend/app/models/chat_message.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Chat Message 持久化模型 — 保存 Agent 会话中的每条消息
|
||||
"""
|
||||
from sqlalchemy import Column, String, Text, Integer, DateTime, ForeignKey, Index, func
|
||||
from sqlalchemy.dialects.mysql import CHAR
|
||||
from app.core.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
"""聊天消息表"""
|
||||
__tablename__ = "chat_messages"
|
||||
|
||||
id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="消息ID")
|
||||
session_id = Column(CHAR(36), nullable=False, comment="会话ID")
|
||||
agent_id = Column(CHAR(36), ForeignKey("agents.id", ondelete="SET NULL"), nullable=True, comment="智能体ID")
|
||||
user_id = Column(CHAR(36), ForeignKey("users.id", ondelete="SET NULL"), nullable=True, comment="用户ID")
|
||||
role = Column(String(20), nullable=False, comment="角色: user/assistant/tool/system")
|
||||
content = Column(Text, comment="消息内容")
|
||||
tool_name = Column(String(100), nullable=True, comment="工具名称(仅tool消息)")
|
||||
tool_input = Column(Text, nullable=True, comment="工具输入参数JSON(仅tool消息)")
|
||||
tool_output = Column(Text, nullable=True, comment="工具输出结果(仅tool消息)")
|
||||
iteration = Column(Integer, default=0, comment="Agent 迭代序号")
|
||||
created_at = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_chat_messages_session_created", "session_id", "created_at"),
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"session_id": self.session_id,
|
||||
"agent_id": self.agent_id,
|
||||
"user_id": self.user_id,
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
"tool_name": self.tool_name,
|
||||
"tool_input": self.tool_input,
|
||||
"tool_output": self.tool_output,
|
||||
"iteration": self.iteration,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
Reference in New Issue
Block a user