chore(api/core): Improve FileVar's type hint and imports. (#7290)

This commit is contained in:
-LAN-
2024-08-15 12:43:18 +08:00
committed by GitHub
parent 6ff7fd80a1
commit 8f16165f92
7 changed files with 68 additions and 59 deletions

View File

@@ -1,11 +1,10 @@
import enum
import json
import os
from typing import Optional
from typing import TYPE_CHECKING, Optional
from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
PromptMessage,
@@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class ModelMode(enum.Enum):
COMPLETION = 'completion'
@@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
prompt_template_entity: PromptTemplateEntity,
inputs: dict,
query: str,
files: list[FileVar],
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> \
@@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list[FileVar],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict,
query: str,
context: Optional[str],
files: list[FileVar],
files: list["FileVar"],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
return [self.get_last_user_message(prompt, files)], stops
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files: