"""Unit tests for the text chunker."""

from __future__ import annotations

import pytest

from app.services.text_chunker import (
    _merge_segments,
    _split_long_paragraph,
    _split_paragraphs,
    chunk_text,
)


class TestSplitParagraphs:
    """Tests for splitting raw text into paragraphs."""

    def test_empty_text(self):
        assert _split_paragraphs("") == []
        assert _split_paragraphs(" ") == []
        assert _split_paragraphs("\n\n\n") == []

    def test_single_paragraph(self):
        result = _split_paragraphs("Hello world")
        assert result == ["Hello world"]

    def test_multiple_paragraphs(self):
        text = "第一段内容。\n\n第二段内容。\n\n第三段内容。"
        result = _split_paragraphs(text)
        assert len(result) == 3
        assert "第一段" in result[0]
        assert "第二段" in result[1]
        assert "第三段" in result[2]

    def test_mixed_newlines(self):
        # Runs of two or more newlines should all act as one paragraph break.
        text = "段1\n\n\n段2\n\n段3"
        result = _split_paragraphs(text)
        assert len(result) == 3


class TestSplitLongParagraph:
    """Tests for splitting a single paragraph that exceeds ``chunk_size``."""

    def test_short_paragraph(self):
        result = _split_long_paragraph("短文本", chunk_size=500)
        assert result == ["短文本"]

    def test_long_paragraph_chinese(self):
        para = "第一句。" * 200  # 4 chars * 200 = 800 chars, exceeds 500
        assert len(para) > 500  # guard the precondition the test relies on
        result = _split_long_paragraph(para, chunk_size=500)
        assert len(result) >= 2
        assert all(len(c) <= 500 for c in result)

    def test_long_paragraph_english(self):
        para = "Sentence. " * 200  # 10 chars * 200 = 2000 chars, exceeds 500
        assert len(para) > 500
        result = _split_long_paragraph(para, chunk_size=500)
        assert len(result) >= 2
        assert all(len(c) <= 500 for c in result)

    def test_no_sentence_boundary(self):
        """Fall back to a hard character cut when there is no sentence terminator."""
        para = "a" * 1000
        result = _split_long_paragraph(para, chunk_size=500)
        assert len(result) == 2
        assert len(result[0]) == 500
        assert len(result[1]) == 500


class TestMergeSegments:
    """Tests for merging segments into chunks of at most ``chunk_size``."""

    def test_empty(self):
        assert _merge_segments([], 500, 0) == []

    def test_single_segment(self):
        result = _merge_segments(["hello"], 500, 0)
        assert result == ["hello"]

    def test_merge_short_segments(self):
        # Each segment is 100 chars. Together (~300 chars plus any separator)
        # they fit in a single 500-char chunk.
        segs = ["a" * 100, "b" * 100, "c" * 100]
        result = _merge_segments(segs, 500, 0)
        assert len(result) == 1
        # Every segment must survive the merge intact.
        for seg in segs:
            assert seg in result[0]

    def test_split_large_segments(self):
        # 3 * 300 = 900 chars cannot fit in one 500-char chunk.
        segs = ["a" * 300, "b" * 300, "c" * 300]
        result = _merge_segments(segs, 500, 0)
        assert len(result) >= 2

    def test_overlap(self):
        # 300 + 300 = 600 chars cannot fit in one 500-char chunk, so even with
        # overlap enabled the segments must end up in separate chunks.
        segs = ["a" * 300, "b" * 300]
        result = _merge_segments(segs, 500, 50)
        assert len(result) >= 2
        assert any("b" * 300 in chunk for chunk in result)


class TestChunkText:
    """End-to-end tests for ``chunk_text``."""

    def test_empty_text(self):
        assert chunk_text("") == []
        assert chunk_text(" ") == []
        assert chunk_text(None) == []  # type: ignore

    def test_short_text(self):
        result = chunk_text("Hello world", chunk_size=500, chunk_overlap=0)
        assert len(result) == 1
        assert "Hello world" in result[0]

    def test_normal_text(self):
        # Three paragraphs separated by blank lines.
        text = (
            "这是第一段。它包含一些内容。\n\n"
            "这是第二段。它也有一些内容。而且更长一些。\n\n"
            "这是第三段。最后一段内容。"
        )
        result = chunk_text(text, chunk_size=500, chunk_overlap=0)
        assert len(result) >= 1
        # All of the content must be present in the output.
        all_text = "".join(result)
        assert "第一段" in all_text
        assert "第二段" in all_text
        assert "第三段" in all_text

    def test_chinese_text(self):
        """Short Chinese text (well under chunk_size) yields a single chunk."""
        text = "我喜欢吃川菜。特别是麻辣火锅。还有水煮鱼。这些都很美味。"
        assert len(text) < 100
        result = chunk_text(text, chunk_size=100, chunk_overlap=0)
        assert len(result) == 1
        assert "川菜" in result[0]
        assert "美味" in result[0]

    def test_overlap_between_chunks(self):
        """Long text is split into several non-empty chunks when overlap is enabled."""
        para = "这是一个很长的段落。" * 50  # 10 chars * 50 = 500 chars
        result = chunk_text(para, chunk_size=200, chunk_overlap=50)
        # 500 chars cannot fit into 200-char chunks, so splitting is mandatory.
        # Asserting unconditionally stops the test from passing vacuously.
        assert len(result) > 1
        # NOTE(review): the input repeats one sentence, so overlapping content
        # cannot be told apart from ordinary content. Use distinct sentences if
        # the overlap itself should be verified.
        assert len(result[0]) > 0
        assert len(result[1]) > 0

    # NOTE(review): this is the only test marked `unit`; confirm whether the
    # marker was meant to apply to the whole module.
    @pytest.mark.unit
    def test_with_mixed_punctuation(self):
        text = "Hello! How are you? I am fine. Thank you. 你好!最近怎么样?我很好。谢谢。"
        result = chunk_text(text, chunk_size=200, chunk_overlap=0)
        assert len(result) >= 1
        # Both the English and the Chinese parts must survive chunking.
        joined = "".join(result)
        assert "Hello!" in joined
        assert "谢谢" in joined