Files
aiagent/scripts/seed_coding_agent.py
renjianbo 7aba0f9bc5 fix: 修复 Agent 流式对话无响应和工具 schema 兼容性问题
- 在 `run_stream()` LLM 调用前 yield `think` 事件,前端即时显示"思考中..."
- 修复 tool schema 规范化逻辑:`{"function":{...}}` 格式缺少 `type` 字段导致 LLM API 拒绝
- 启动时从数据库加载自定义工具(`load_tools_from_db`),解决重启后工具丢失
- 前端 SSE 添加 60s 超时保护,任何事件类型均触发 `receivedFirstEvent`
- 流式失败自动降级到非流式 POST
- 添加 `scripts/seed_coding_agent.py` 和 `scripts/test_coding_agent.py`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-02 00:38:41 +08:00

362 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""创建代码编程助手:开发者工具 + Agent 配置"""
import json
import urllib.request
import urllib.parse
import uuid
BASE = "http://localhost:8037"
def req(method, path, headers=None, body=None, raw_body=None):
hdrs = {"Content-Type": "application/json"}
if headers: hdrs.update(headers)
data = raw_body if raw_body else (json.dumps(body).encode() if body else None)
r = urllib.request.Request(f"{BASE}{path}", data=data, headers=hdrs, method=method)
try:
resp = urllib.request.urlopen(r, timeout=15)
return resp.status, json.loads(resp.read())
except urllib.request.HTTPError as e:
return e.code, json.loads(e.read())
except Exception as e:
return 0, {"error": str(e)}
# 1. 登录/注册
_, _ = req("POST", "/api/v1/auth/register", body={
"username": "codingbot", "email": "coding@test.com", "password": "test123456"
})
status, login_data = req("POST", "/api/v1/auth/login",
headers={"Content-Type": "application/x-www-form-urlencoded"},
raw_body=urllib.parse.urlencode({"username": "codingbot", "password": "test123456"}).encode())
if status != 200:
print(f"Login failed: {login_data}")
exit(1)
token = login_data["access_token"]
auth = {"Authorization": f"Bearer {token}"}
print("OK 用户已登录")
# 2. 创建开发者工具
dev_tools = [
{
"name": "execute_code",
"description": "在安全沙箱中执行Python代码返回执行结果和stdout/stderr。支持任意Python代码可用于运行脚本、测试算法、数据处理等",
"category": "开发者工具",
"implementation_type": "code",
"is_public": True,
"function_schema": {
"name": "execute_code",
"description": "执行Python代码",
"parameters": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "要执行的Python代码"},
"timeout": {"type": "integer", "description": "超时秒数", "default": 10}
},
"required": ["code"]
}
},
"implementation_config": {
"source": """def run(args):
import sys, io, contextlib, json, traceback, time
code = args.get("code", "")
timeout = int(args.get("timeout", 10))
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
result = {"stdout": "", "stderr": "", "error": None}
namespace = {}
start = time.time()
try:
with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
exec(code, namespace)
result["stdout"] = stdout_capture.getvalue()
result["stderr"] = stderr_capture.getvalue()
if "__returns__" in namespace:
result["result"] = str(namespace["__returns__"])
for key in ["result", "output", "ret"]:
if key in namespace and key not in ("result",):
result["result"] = str(namespace[key])
break
except Exception as e:
result["error"] = traceback.format_exc()
result["stderr"] = stderr_capture.getvalue()
result["elapsed_ms"] = int((time.time() - start) * 1000)
return result"""
}
},
{
"name": "grep_search",
"description": "在项目文件中搜索文本,支持正则表达式和通配符过滤。类似 grep 命令",
"category": "开发者工具",
"implementation_type": "code",
"is_public": True,
"function_schema": {
"name": "grep_search",
"description": "搜索项目文件中的文本",
"parameters": {
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "搜索模式(支持正则)"},
"file_pattern": {"type": "string", "description": "文件通配符过滤,如 *.py, *.ts, *.vue", "default": "*"},
"path": {"type": "string", "description": "搜索路径相对于项目根默认backend", "default": "."},
"max_results": {"type": "integer", "description": "最大结果数", "default": 20}
},
"required": ["pattern"]
}
},
"implementation_config": {
"source": """def run(args):
import os, re, fnmatch
pattern = args.get("pattern", "")
file_pattern = args.get("file_pattern", "*")
root = args.get("path", ".")
max_results = int(args.get("max_results", 20))
results = []
errors = []
try:
for dirpath, dirnames, filenames in os.walk(root):
dirnames[:] = [d for d in dirnames if not d.startswith(".") and d != "node_modules" and d != "__pycache__"]
for f in sorted(filenames):
if not fnmatch.fnmatch(f, file_pattern):
continue
fpath = os.path.join(dirpath, f)
try:
with open(fpath, "r", encoding="utf-8", errors="replace") as fh:
for lineno, line in enumerate(fh, 1):
if re.search(pattern, line):
relpath = os.path.relpath(fpath, root)
results.append(f"{relpath}:{lineno}:{line.rstrip()[:200]}")
if len(results) >= max_results:
break
except Exception as e:
errors.append(str(e))
if len(results) >= max_results:
break
if len(results) >= max_results:
break
except Exception as e:
errors.append(str(e))
return {"results": results, "count": len(results), "errors": errors[:5], "truncated": len(results) >= max_results}"""
}
},
{
"name": "list_files",
"description": "列出项目目录中的文件和子目录,支持递归和过滤",
"category": "开发者工具",
"implementation_type": "code",
"is_public": True,
"function_schema": {
"name": "list_files",
"description": "列出目录文件",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "目录路径(相对于项目根)", "default": "."},
"recursive": {"type": "boolean", "description": "是否递归", "default": False},
"max_depth": {"type": "integer", "description": "递归最大深度", "default": 2}
},
"required": []
}
},
"implementation_config": {
"source": """def run(args):
import os
path = args.get("path", ".")
recursive = bool(args.get("recursive", False))
max_depth = int(args.get("max_depth", 2))
skip_dirs = {".git", "node_modules", "__pycache__", ".venv", ".claude", "dist", ".vite"}
skip_ext = {".pyc", ".pyo"}
def _walk(dirpath, depth=0):
items = []
try:
for name in sorted(os.listdir(dirpath)):
if name.startswith("."):
continue
full = os.path.join(dirpath, name)
rel = os.path.relpath(full, path)
is_dir = os.path.isdir(full)
if is_dir and name in skip_dirs:
continue
size = ""
if not is_dir:
try: size = os.path.getsize(full)
except: size = 0
if size and size > 1024:
size = f"{size/1024:.1f}KB"
elif size:
size = f"{size}B"
ext = os.path.splitext(name)[1]
if ext in skip_ext:
continue
items.append({
"name": rel.replace("\\\\", "/"),
"type": "dir" if is_dir else "file",
"size": size,
})
if is_dir and recursive and depth < max_depth:
items.extend(_walk(full, depth + 1))
except Exception as e:
items.append({"name": f"[error: {e}]", "type": "error"})
return items
all_items = _walk(path)
dirs = [i for i in all_items if i["type"] == "dir"]
files = [i for i in all_items if i["type"] == "file"]
return {"path": path, "directories": len(dirs), "files": len(files), "items": dirs + files}"""
}
},
{
"name": "git_log",
"description": "查看Git提交历史获取最近更改记录",
"category": "开发者工具",
"implementation_type": "code",
"is_public": True,
"function_schema": {
"name": "git_log",
"description": "查看Git提交历史",
"parameters": {
"type": "object",
"properties": {
"count": {"type": "integer", "description": "最近提交数", "default": 10},
"path": {"type": "string", "description": "查看特定文件的提交历史(可选)"}
},
"required": []
}
},
"implementation_config": {
"source": """def run(args):
import subprocess, os
count = int(args.get("count", 10))
file_path = args.get("path")
cmds = ["git", "log", f"--max-count={count}", "--pretty=format:%h|%an|%ad|%s", "--date=short"]
if file_path:
cmds.append("--")
cmds.append(file_path)
try:
result = subprocess.run(cmds, capture_output=True, text=True, timeout=15)
if result.returncode != 0:
return {"error": result.stderr[:500], "commits": []}
commits = []
for line in result.stdout.strip().split("\\n"):
if not line: continue
parts = line.split("|", 3)
if len(parts) == 4:
commits.append({"hash": parts[0], "author": parts[1], "date": parts[2], "message": parts[3]})
return {"commits": commits, "count": len(commits)}
except Exception as e:
return {"error": str(e), "commits": []}"""
}
},
]
created = 0
failed = 0
for t in dev_tools:
status, data = req("POST", "/api/v1/tools", headers=auth, body=t)
if status == 201:
print(f" OK {t['name']}")
created += 1
elif status == 400 and "已存在" in str(data.get("detail", "")):
print(f" - {t['name']} (already exists)")
created += 1
else:
print(f" FAIL {t['name']}: {data.get('detail', data)}")
failed += 1
print(f"Tools created: {created} ok, {failed} failed")
# 3. 创建编程助手 Agent (with workflow config)
_start_id = str(uuid.uuid4())
_llm_id = str(uuid.uuid4())
_end_id = str(uuid.uuid4())
agent_config = {
"name": "代码编程助手",
"description": "专业的代码编程助手,能够理解项目结构、搜索代码、执行和测试代码",
"workflow_config": {
"nodes": [
{
"id": _start_id,
"type": "start",
"position": {"x": 100, "y": 200},
"data": {"label": "开始"},
},
{
"id": _llm_id,
"type": "llm",
"position": {"x": 350, "y": 200},
"data": {
"label": "代码编程助手",
"system_prompt": (
"你是代码编程助手 CodeBot一个专业的软件工程AI助手。\n\n"
"## 核心能力\n"
"你擅长阅读、理解、编写和调试代码。你可以使用各种工具来帮助用户完成编程任务。\n\n"
"## 可用工具\n"
"- **file_read**: 读取项目文件\n"
"- **file_write**: 写入/修改文件\n"
"- **execute_code**: 在沙箱中执行Python代码快速验证逻辑\n"
"- **grep_search**: 在项目中搜索代码\n"
"- **list_files**: 浏览项目目录结构\n"
"- **git_log**: 查看Git提交历史\n"
"- **http_request**: 发送HTTP请求\n"
"- **text_analyze**: 文本分析\n"
"- **json_process**: JSON处理\n\n"
"## 工作流程\n"
"1. 先理解用户需求,必要时浏览项目结构了解代码组织\n"
"2. 搜索相关代码定位需要修改或参考的位置\n"
"3. 阅读相关文件完整理解上下文\n"
"4. 编写或修改代码\n"
"5. 使用 execute_code 测试代码逻辑\n"
"6. 向用户解释修改的内容和原因\n\n"
"## 回答风格\n"
"- 清晰、准确、有逻辑\n"
"- 展示代码时添加适当注释\n"
"- 解释代码的原理和设计思路\n"
"- 如果存在多种方案,对比优缺点\n"
"- 指出潜在的风险和注意事项\n\n"
"## 边界\n"
"若用户问「你有什么能力」「你能做什么」,只介绍与编程、软件工程及上文工具相关的能力;"
"不要列举写诗、泛泛日常助手等与编程无关的能力。"
),
"model": "deepseek-v4-flash",
"provider": "deepseek",
"temperature": 0.3,
"max_iterations": 30,
"tools": [
"file_read", "file_write", "execute_code", "grep_search",
"list_files", "git_log", "http_request", "text_analyze", "json_process",
],
"memory": True,
},
},
{
"id": _end_id,
"type": "end",
"position": {"x": 600, "y": 200},
"data": {"label": "结束"},
},
],
"edges": [
{"id": str(uuid.uuid4()), "source": _start_id, "target": _llm_id},
{"id": str(uuid.uuid4()), "source": _llm_id, "target": _end_id},
],
},
"budget_config": {
"max_llm_invocations": 100,
"max_tool_calls": 200,
},
}
status, data = req("POST", "/api/v1/agents", headers=auth, body=agent_config)
if status in (200, 201):
agent_id = data.get("id", "")
print(f"\nOK Agent created: {agent_id}")
else:
print(f"\nFAIL Agent creation: {data}")
# Try to update existing agent
print("Checking existing agents...")
s, agents = req("GET", "/api/v1/agents", headers=auth)
if s == 200 and isinstance(agents, list):
for a in agents:
if a.get("name") == "代码编程助手":
print(f" Already exists: {a.get('id')}")
print("\nDone! Go to Agent Management -> Code Programming Assistant -> Chat to start using it.")