chore: remove Langchain tools import (#3407)

This commit is contained in:
Jyong
2024-04-12 16:26:09 +08:00
committed by GitHub
parent c227f3d985
commit 0737e930cb
9 changed files with 98 additions and 73 deletions

View File

@@ -1,8 +1,6 @@
import threading
from typing import Optional
from flask import Flask, current_app
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -10,6 +8,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(BaseTool):
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
tenant_id: str
dataset_ids: list[str]
top_k: int = 2
score_threshold: Optional[float] = None
reranking_provider_name: str
reranking_model_name: str
return_resource: bool
retriever_from: str
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
@@ -149,9 +143,6 @@ class DatasetMultiRetrieverTool(BaseTool):
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler]):
with flask_app.app_context():

View File

@@ -0,0 +1,34 @@
from abc import abstractmethod
from typing import Any, Optional
from msal_extensions.persistence import ABC
from pydantic import BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
class DatasetRetrieverBaseTool(BaseModel, ABC):
"""Tool for querying a Dataset."""
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
class Config:
arbitrary_types_allowed = True
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing,
"""

View File

@@ -1,10 +1,8 @@
from typing import Optional
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(BaseTool):
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
tenant_id: str
dataset_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -153,7 +145,4 @@ class DatasetRetrieverTool(BaseTool):
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
return str("\n".join(document_context_list))