feat: integrate flask-orjson for improved JSON serialization performance (#23935)

This commit is contained in:
-LAN-
2025-08-14 19:50:59 +08:00
committed by GitHub
parent 4a2e6af9b5
commit e340fccafb
8 changed files with 64 additions and 27 deletions

View File

@@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
from core.variables.utils import SegmentJSONEncoder
from core.variables.utils import dumps_with_segments
class TemplateTransformer(ABC):
@@ -93,7 +93,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded

View File

@@ -1,7 +1,7 @@
import json
from collections import defaultdict
from typing import Any, Optional
import orjson
from pydantic import BaseModel
from configs import dify_config
@@ -134,13 +134,13 @@ class Jieba(BaseKeyword):
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key):
storage.delete(file_key)
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
@@ -156,12 +156,11 @@ class Jieba(BaseKeyword):
data_source_type=keyword_data_source_type,
)
if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(
dataset_keyword_table.keyword_table = dumps_with_sets(
{
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
},
cls=SetEncoder,
}
)
db.session.add(dataset_keyword_table)
db.session.commit()
@@ -252,8 +251,13 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
def set_orjson_default(obj: Any) -> Any:
"""Default function for orjson serialization of set types"""
if isinstance(obj, set):
return list(obj)
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def dumps_with_sets(obj: Any) -> str:
"""JSON dumps with set support using orjson"""
return orjson.dumps(obj, default=set_orjson_default).decode("utf-8")

View File

@@ -1,5 +1,7 @@
import json
from collections.abc import Iterable, Sequence
from typing import Any
import orjson
from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment
@@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
return selectors
class SegmentJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [self.default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
else:
super().default(o)
def segment_orjson_default(o: Any) -> Any:
"""Default function for orjson serialization of Segment types"""
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [segment_orjson_default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
"""JSON dumps with segment support using orjson"""
option = orjson.OPT_NON_STR_KEYS
return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")