Feat/dataset service api (#1245)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
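
Adds the dataset service layer APIs: batch segment creation
(SegmentService.multi_create_segment, backed by a new
VectorService.multi_create_segment_vector), document name updates when
re-creating from an original document, a new FileService covering file
upload, text upload, and preview, and file-specific service errors. Also
fixes DatasetService to resolve the embedding model from the passed
tenant_id instead of current_user.
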
@@ -96,7 +96,7 @@ class DatasetService:
         embedding_model = None
         if indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
+                tenant_id=tenant_id
             )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
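
(The fix above switches the embedding-model lookup from the request-bound
current_user to the tenant_id argument the service method already receives,
presumably so it also works outside a request context.)
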
@@ -477,6 +477,7 @@ class DocumentService:
                 )
+                dataset.collection_binding_id = dataset_collection_binding.id


         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
@@ -626,6 +627,9 @@ class DocumentService:
             document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
             if document.display_status != 'available':
                 raise ValueError("Document is not available")
+            # update document name
+            if 'name' in document_data and document_data['name']:
+                document.name = document_data['name']
             # save process rule
             if 'process_rule' in document_data and document_data['process_rule']:
                 process_rule = document_data["process_rule"]
@@ -767,7 +771,7 @@ class DocumentService:
         return dataset, documents, batch

     @classmethod
     def document_create_args_validate(cls, args: dict):
         if 'original_document_id' not in args or not args['original_document_id']:
             DocumentService.data_source_args_validate(args)
             DocumentService.process_rule_args_validate(args)
@@ -1014,6 +1018,66 @@ class SegmentService:
         segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
         return segment

+    @classmethod
+    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+            DocumentSegment.document_id == document.id
+        ).scalar()
+        pre_segment_data_list = []
+        segment_data_list = []
+        for segment_item in segments:
+            content = segment_item['content']
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality' and embedding_model:
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
+            segment_document = DocumentSegment(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=document.dataset_id,
+                document_id=document.id,
+                index_node_id=doc_id,
+                index_node_hash=segment_hash,
+                position=max_position + 1 if max_position else 1,
+                content=content,
+                word_count=len(content),
+                tokens=tokens,
+                status='completed',
+                indexing_at=datetime.datetime.utcnow(),
+                completed_at=datetime.datetime.utcnow(),
+                created_by=current_user.id
+            )
+            if document.doc_form == 'qa_model':
+                segment_document.answer = segment_item['answer']
+            db.session.add(segment_document)
+            segment_data_list.append(segment_document)
+            pre_segment_data = {
+                'segment': segment_document,
+                'keywords': segment_item['keywords']
+            }
+            pre_segment_data_list.append(pre_segment_data)
+
+        try:
+            # save vector index
+            VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
+        except Exception as e:
+            logging.exception("create segment index failed")
+            for segment_document in segment_data_list:
+                segment_document.enabled = False
+                segment_document.disabled_at = datetime.datetime.utcnow()
+                segment_document.status = 'error'
+                segment_document.error = str(e)
+        db.session.commit()
+        return segment_data_list
+
     @classmethod
     def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
         indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
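
For context, a minimal sketch of how a caller might drive the new batch API
(the payload shape follows the loop above; `dataset` and `document` are
assumed to be already-loaded ORM objects, not part of this diff):

    segments = [
        {'content': 'First chunk of the source text.', 'keywords': ['chunk', 'source']},
        {'content': 'Second chunk.', 'keywords': ['second']},
        # items for a doc_form == 'qa_model' document also need an 'answer' key
    ]
    created_segments = SegmentService.multi_create_segment(segments, document, dataset)

Note the failure handling: if building the vector index throws, the segments
are still committed, but disabled and marked status 'error', so callers see a
partial failure rather than losing the whole batch.
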
@@ -1,7 +1,7 @@
 # -*- coding:utf-8 -*-
 __all__ = [
     'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
-    'app', 'completion', 'audio'
+    'app', 'completion', 'audio', 'file'
 ]

 from . import *
@@ -3,3 +3,11 @@ from services.errors.base import BaseServiceError

 class FileNotExistsError(BaseServiceError):
     pass
+
+
+class FileTooLargeError(BaseServiceError):
+    description = "{message}"
+
+
+class UnsupportedFileTypeError(BaseServiceError):
+    pass
api/services/file_service.py (new file, +123 lines)
@@ -0,0 +1,123 @@
+import datetime
+import hashlib
+import time
+import uuid
+
+from cachetools import TTLCache
+from flask import request, current_app
+from flask_login import current_user
+from werkzeug.datastructures import FileStorage
+from werkzeug.exceptions import NotFound
+
+from core.data_loader.file_extractor import FileExtractor
+from extensions.ext_storage import storage
+from extensions.ext_database import db
+from models.model import UploadFile
+from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
+
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
+PREVIEW_WORDS_LIMIT = 3000
+cache = TTLCache(maxsize=None, ttl=30)
+
+
+class FileService:
+
+    @staticmethod
+    def upload_file(file: FileStorage) -> UploadFile:
+        # read file content
+        file_content = file.read()
+        # get file size
+        file_size = len(file_content)
+
+        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+        if file_size > file_size_limit:
+            message = f'File size exceeded. {file_size} > {file_size_limit}'
+            raise FileTooLargeError(message)
+
+        extension = file.filename.split('.')[-1]
+        if extension.lower() not in ALLOWED_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
+        # user uuid as file name
+        file_uuid = str(uuid.uuid4())
+        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
+
+        # save file to storage
+        storage.save(file_key, file_content)
+
+        # save file to db
+        config = current_app.config
+        upload_file = UploadFile(
+            tenant_id=current_user.current_tenant_id,
+            storage_type=config['STORAGE_TYPE'],
+            key=file_key,
+            name=file.filename,
+            size=file_size,
+            extension=extension,
+            mime_type=file.mimetype,
+            created_by=current_user.id,
+            created_at=datetime.datetime.utcnow(),
+            used=False,
+            hash=hashlib.sha3_256(file_content).hexdigest()
+        )
+
+        db.session.add(upload_file)
+        db.session.commit()
+
+        return upload_file
+
+    @staticmethod
+    def upload_text(text: str, text_name: str) -> UploadFile:
+        # user uuid as file name
+        file_uuid = str(uuid.uuid4())
+        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
+
+        # save file to storage
+        storage.save(file_key, text.encode('utf-8'))
+
+        # save file to db
+        config = current_app.config
+        upload_file = UploadFile(
+            tenant_id=current_user.current_tenant_id,
+            storage_type=config['STORAGE_TYPE'],
+            key=file_key,
+            name=text_name + '.txt',
+            size=len(text),
+            extension='txt',
+            mime_type='text/plain',
+            created_by=current_user.id,
+            created_at=datetime.datetime.utcnow(),
+            used=True,
+            used_by=current_user.id,
+            used_at=datetime.datetime.utcnow()
+        )
+
+        db.session.add(upload_file)
+        db.session.commit()
+
+        return upload_file
+
+    @staticmethod
+    def get_file_preview(file_id: str) -> str:
+        # get file storage key
+        key = file_id + request.path
+        cached_response = cache.get(key)
+        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
+            return cached_response['response']
+
+        upload_file = db.session.query(UploadFile) \
+            .filter(UploadFile.id == file_id) \
+            .first()
+
+        if not upload_file:
+            raise NotFound("File not found")
+
+        # extract text from file
+        extension = upload_file.extension
+        if extension.lower() not in ALLOWED_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
+        text = FileExtractor.load(upload_file, return_text=True)
+        text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
+
+        return text
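
For context, a minimal sketch of a Flask handler driving the new
FileService.upload_file (the app object, route, and JSON shape are
assumptions for illustration, not part of this diff):

    from flask import Flask, jsonify, request

    from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
    from services.file_service import FileService

    app = Flask(__name__)  # hypothetical app; Dify wires this through its API blueprints

    @app.route('/files/upload', methods=['POST'])
    def upload():
        try:
            upload_file = FileService.upload_file(request.files['file'])
        except FileTooLargeError as e:
            return jsonify({'error': e.description}), 413
        except UnsupportedFileTypeError:
            return jsonify({'error': 'unsupported file type'}), 415
        return jsonify({'id': upload_file.id, 'name': upload_file.name}), 201

Note that upload_file reads the whole upload into memory to size-check and
hash it, and get_file_preview memoizes previews for 30 seconds in a TTLCache
keyed by file id and request path.
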
@@ -35,6 +35,32 @@ class VectorService:
         else:
             index.add_texts([document])

+    @classmethod
+    def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
+        documents = []
+        for pre_segment_data in pre_segment_data_list:
+            segment = pre_segment_data['segment']
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
+            )
+            documents.append(document)
+
+        # save vector index
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents, duplicate_check=True)
+
+        # save keyword index
+        keyword_index = IndexBuilder.get_index(dataset, 'economy')
+        if keyword_index:
+            keyword_index.multi_create_segment_keywords(pre_segment_data_list)
+
     @classmethod
     def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
         # update segment index task
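
Design note: the vector write only happens for 'high_quality' datasets
(IndexBuilder.get_index returns no vector index otherwise), while keyword
entries are written in either case; duplicate_check=True appears to guard
against re-adding doc_ids that already exist in the vector store.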