""" 项目源码扫描器 — 程序化读取项目关键源文件,返回格式化的 Markdown。 可用于: - team_orchestrator 架构师预扫描(自动注入 prompt) - project_scan 工具(Agent 手动调用) - 单独的 API 端点 """ from __future__ import annotations import logging from pathlib import Path from typing import Dict, List, Optional logger = logging.getLogger(__name__) # ═══════════════════════════════════════════════ # 项目类型识别规则 # ═══════════════════════════════════════════════ _PROJECT_TYPE_RULES: Dict[str, tuple] = { "build.gradle.kts": ("android_kotlin", [ "build.gradle.kts", "settings.gradle.kts", "app/build.gradle.kts", "app/src/main/AndroidManifest.xml", ]), "build.gradle": ("android_java", [ "build.gradle", "settings.gradle", "app/build.gradle", "app/src/main/AndroidManifest.xml", ]), "package.json": ("frontend_js", [ "package.json", "vite.config.ts", "vite.config.js", "next.config.js", "nuxt.config.js", ]), "requirements.txt": ("python", [ "requirements.txt", "pyproject.toml", "setup.py", ]), "pyproject.toml": ("python", [ "pyproject.toml", "requirements.txt", ]), "go.mod": ("go", [ "go.mod", "go.sum", "main.go", ]), } # 各项目类型的深度扫描文件模式 _SCAN_DEPTH_PATTERNS: Dict[str, List[str]] = { "android_kotlin": [ "app/src/main/java/**/*ViewModel.kt", "app/src/main/java/**/*Screen.kt", "app/src/main/java/**/*Repository.kt", "app/src/main/java/**/di/*.kt", "app/src/main/java/**/data/remote/*.kt", "app/src/main/java/**/data/local/*.kt", "app/src/main/java/**/ui/**/*.kt", "app/src/main/AndroidManifest.xml", ], "android_java": [ "app/src/main/java/**/*ViewModel.java", "app/src/main/java/**/*Activity.java", "app/src/main/java/**/*Fragment.java", "app/src/main/java/**/di/*.java", "app/src/main/java/**/data/remote/*.java", "app/src/main/AndroidManifest.xml", ], "frontend_js": [ "src/**/*.vue", "src/**/*.tsx", "src/**/*.jsx", "src/router/**/*.ts", "src/router/**/*.js", "src/stores/**/*.ts", "src/stores/**/*.js", "src/api/**/*.ts", "src/api/**/*.js", "src/App.vue", "src/App.tsx", "src/main.ts", "src/main.js", "src/index.tsx", ], "python": [ "app/**/*.py", "api/**/*.py", "routers/**/*.py", "models/**/*.py", "services/**/*.py", "schemas/**/*.py", "main.py", "app.py", ], "go": [ "cmd/**/*.go", "internal/**/*.go", "pkg/**/*.go", "main.go", "handlers/**/*.go", "models/**/*.go", ], } # 每种类型读取的最大数量和大小限制 _MAX_SOURCE_FILES = 12 _MAX_SOURCE_BYTES = 50000 _MAX_BYTES_PER_FILE = 8000 # 文件优先级排序关键字 _PRIORITY_ORDER = [ 'ViewModel', 'Screen', 'Repository', 'Service', 'Dao', 'Client', 'Interceptor', 'Module', 'Activity', 'Application', 'Main', ] def identify_project_type(dir_path: Path) -> Optional[str]: """通过检查目录中的标志文件识别项目类型。""" for marker, (ptype, _) in _PROJECT_TYPE_RULES.items(): if (dir_path / marker).exists(): return ptype # 只搜索一层子目录 for child in dir_path.iterdir(): if child.is_dir() and not child.name.startswith('.'): if (child / marker).exists(): return ptype return None def scan_source_code(target_dir: Path, max_files: int = _MAX_SOURCE_FILES, max_bytes: int = _MAX_SOURCE_BYTES, max_per_file: int = _MAX_BYTES_PER_FILE) -> str: """扫描项目目录,读取关键源文件内容,返回格式化的 Markdown 字符串。 Args: target_dir: 项目根目录路径 max_files: 最多读取的文件数量 max_bytes: 总共最多读取的字节数 max_per_file: 单文件最多读取的字节数 Returns: 格式化的 Markdown 字符串,可直接嵌入 prompt """ if not target_dir.is_dir(): return "" ptype = identify_project_type(target_dir) parts: list[str] = [] total_bytes = 0 files_read = 0 if not ptype: # 未识别到项目类型,至少列出目录结构 try: top_items = sorted( [p for p in target_dir.iterdir() if not p.name.startswith('.')], key=lambda p: (p.is_file(), p.name) )[:30] listing = "\n".join( f"- {p.name}{'/' if p.is_dir() else ''}" for p in top_items ) if listing: parts.append(f"### 目录: `{target_dir}`\n" f"(项目类型未识别,目录列表)\n\n{listing}\n") except (PermissionError, OSError): pass return "\n".join(parts) if parts else "" # 收集要读取的文件 patterns = _SCAN_DEPTH_PATTERNS.get(ptype, []) files_to_read: list[Path] = [] seen: set[Path] = set() for pat in patterns: try: for match in sorted(target_dir.glob(pat)): if match.is_file() and match.suffix in ( '.kt', '.java', '.py', '.ts', '.js', '.vue', '.tsx', '.jsx', '.go', '.xml', '.kts', '.gradle', '.json', '.yaml', '.yml', '.toml', ): if match not in seen: seen.add(match) files_to_read.append(match) except (PermissionError, OSError): continue if not files_to_read: parts.append(f"### 项目: `{target_dir}` (类型: {ptype})\n" f"(未找到匹配的源文件)\n") return "\n".join(parts) if parts else "" # 按优先级排序 def sort_key(filepath: Path) -> int: name = filepath.stem for i, kw in enumerate(_PRIORITY_ORDER): if kw.lower() in name.lower(): return i return len(_PRIORITY_ORDER) files_to_read.sort(key=sort_key) # 读取文件(受总量限制) module_parts = [] module_bytes = 0 module_count = 0 for fp in files_to_read: if files_read >= max_files or total_bytes >= max_bytes: break try: content = fp.read_text(encoding='utf-8', errors='replace') if len(content) > max_per_file: half = max_per_file // 2 content = (content[:half] + f"\n\n... (省略 {len(content) - max_per_file} 字节) ...\n\n" + content[-half:]) rel_path = fp.relative_to(target_dir) if target_dir in fp.parents else fp lang = _get_code_lang(ptype) module_parts.append( f"#### `{rel_path}` ({len(content)} bytes)\n" f"```{lang}\n{content}\n```\n" ) module_bytes += len(content) module_count += 1 total_bytes += len(content) files_read += 1 except (PermissionError, OSError, UnicodeDecodeError): continue if module_parts: parts.insert(0, f"### 项目源码: `{target_dir}` (类型: {ptype}, " f"读取 {module_count} 文件, {module_bytes // 1024}KB)\n") parts.extend(module_parts) if files_read >= max_files or total_bytes >= max_bytes: parts.append(f"\n> ⚠️ 已达源码扫描上限 " f"({files_read} 文件, {total_bytes // 1024}KB)。" f"剩余文件未读取。\n") if not parts: return "" header = ("## 🔍 预扫描源码清单\n\n" "以下是通过程序化扫描自动读取的项目源代码文件内容。" "请基于这些真实代码进行分析,不要依赖设计文档或推测。\n\n") return header + "\n".join(parts) def scan_multiple_dirs(dir_paths: list[Path], max_files: int = _MAX_SOURCE_FILES, max_bytes: int = _MAX_SOURCE_BYTES, max_per_file: int = _MAX_BYTES_PER_FILE) -> str: """扫描多个目录,返回合并的 Markdown 字符串。""" parts: list[str] = [] total_files = 0 total_bytes = 0 for d in dir_paths: result = scan_source_code(d, max_files - total_files, max_bytes - total_bytes, max_per_file) if result: parts.append(result) total_files = min(total_files + _MAX_SOURCE_FILES, max_files) total_bytes = min(total_bytes + _MAX_SOURCE_BYTES, max_bytes) return "\n".join(parts) if parts else "" def _get_code_lang(ptype: str) -> str: """根据项目类型返回代码块的语言标识。""" lang_map = { "android_kotlin": "kotlin", "android_java": "java", "frontend_js": "typescript", "python": "python", "go": "go", } return lang_map.get(ptype, "") # 导出公共 API __all__ = [ "identify_project_type", "scan_source_code", "scan_multiple_dirs", ]