Revert "feat: knowledge admin role" (#6018)

Author: takatost
Date: 2024-07-05 21:31:34 +08:00
Committed by: GitHub
Parent: 71c50b7e20
Commit: 79df8825c8
46 changed files with 350 additions and 1028 deletions

View File

@@ -395,11 +395,6 @@ class DataSetConfig(BaseModel):
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description='whether to enable dataset operator',
default=False,
)
class WorkspaceConfig(BaseModel):
"""

View File

@@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.dataset_service import DatasetService, DocumentService
def _validate_name(name):
@@ -85,12 +85,6 @@ class DatasetListApi(Resource):
else:
item['embedding_available'] = True
if item.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
item.update({'partial_member_list': part_users_list})
else:
item.update({'partial_member_list': []})
response = {
'data': data,
'has_more': len(datasets) == limit,
@@ -114,7 +108,7 @@ class DatasetListApi(Resource):
help='Invalid indexing technique.')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
@@ -146,10 +140,6 @@ class DatasetApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(
@@ -173,11 +163,6 @@ class DatasetApi(Resource):
data['embedding_available'] = False
else:
data['embedding_available'] = True
if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({'partial_member_list': part_users_list})
return data, 200
@setup_required
@@ -203,21 +188,17 @@ class DatasetApi(Resource):
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
)
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
args = parser.parse_args()
data = request.get_json()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get('permission'), data.get('partial_member_list')
)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
@@ -225,17 +206,7 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
if data.get('partial_member_list') and data.get('permission') == 'partial_members':
DatasetPermissionService.update_partial_member_list(dataset_id_str, data.get('partial_member_list'))
else:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({'partial_member_list': partial_member_list})
return result_data, 200
return marshal(dataset, dataset_detail_fields), 200
@setup_required
@login_required
@@ -244,7 +215,7 @@ class DatasetApi(Resource):
dataset_id_str = str(dataset_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor or current_user.is_dataset_operator:
if not current_user.is_editor:
raise Forbidden()
try:
@@ -598,27 +569,6 @@ class DatasetErrorDocs(Resource):
}, 200
class DatasetPermissionUserListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
'data': partial_members_list,
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
@@ -632,4 +582,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
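For reference, the guard these endpoints fall back to after the revert is the plain editor check. A minimal, self-contained sketch follows; StubUser and assert_can_manage_datasets are illustrative stand-ins, not Dify's actual classes.

from werkzeug.exceptions import Forbidden

class StubUser:
    """Illustrative stand-in for the logged-in account object."""
    def __init__(self, role: str):
        self.role = role

    @property
    def is_editor(self) -> bool:
        # With dataset_operator reverted, only owner/admin/editor count as editing roles.
        return self.role in {"owner", "admin", "editor"}

def assert_can_manage_datasets(user: StubUser) -> None:
    # Mirrors the `if not current_user.is_editor: raise Forbidden()` checks
    # used by the dataset endpoints above.
    if not user.is_editor:
        raise Forbidden()

assert_can_manage_datasets(StubUser("admin"))      # passes
# assert_can_manage_datasets(StubUser("normal"))   # would raise Forbidden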

View File

@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
raise NotFound('Dataset not found.')
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_dataset_editor:
if not current_user.is_editor:
raise Forbidden()
try:
@@ -294,11 +294,6 @@ class DatasetInitApi(Resource):
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
location='json')
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
if args['indexing_technique'] == 'high_quality':
try:
model_manager = ModelManager()
@@ -762,19 +757,15 @@ class DocumentStatusApi(DocumentResource):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_dataset_editor:
raise Forbidden()
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check user's permission
DatasetService.check_dataset_permission(dataset, current_user)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
indexing_cache_key = 'document_{}_indexing'.format(document.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
@@ -964,11 +955,10 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()

View File

@@ -36,7 +36,7 @@ class TagListApi(Resource):
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
def patch(self, tag_id):
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()

View File

@@ -131,20 +131,7 @@ class MemberUpdateRoleApi(Resource):
return {'result': 'success'}
class DatasetOperatorMemberListApi(Resource):
"""List all members of current tenant."""
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {'result': 'success', 'accounts': members}, 200
api.add_resource(MemberListApi, '/workspaces/current/members')
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')

View File

@@ -1,42 +0,0 @@
"""add table dataset_permissions
Revision ID: 7e6a8693e07a
Revises: 4ff534e1eb11
Create Date: 2024-06-25 03:20:46.012193
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '7e6a8693e07a'
down_revision = 'b2602e131636'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_permissions',
sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('dataset_id', models.StringUUID(), nullable=False),
sa.Column('account_id', models.StringUUID(), nullable=False),
sa.Column('has_permission', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_permission_pkey')
)
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
batch_op.create_index('idx_dataset_permissions_account_id', ['account_id'], unique=False)
batch_op.create_index('idx_dataset_permissions_dataset_id', ['dataset_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
batch_op.drop_index('idx_dataset_permissions_dataset_id')
batch_op.drop_index('idx_dataset_permissions_account_id')
op.drop_table('dataset_permissions')
# ### end Alembic commands ###

View File

@@ -80,10 +80,6 @@ class Account(UserMixin, db.Model):
self._current_tenant = tenant
@property
def current_role(self):
return self._current_tenant.current_role
def get_status(self) -> AccountStatus:
status_str = self.status
return AccountStatus(status_str)
@@ -114,14 +110,6 @@ class Account(UserMixin, db.Model):
def is_editor(self):
return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
@property
def is_dataset_editor(self):
return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
@property
def is_dataset_operator(self):
return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
class TenantStatus(str, enum.Enum):
NORMAL = 'normal'
ARCHIVE = 'archive'
@@ -132,12 +120,10 @@ class TenantAccountRole(str, enum.Enum):
ADMIN = 'admin'
EDITOR = 'editor'
NORMAL = 'normal'
DATASET_OPERATOR = 'dataset_operator'
@staticmethod
def is_valid_role(role: str) -> bool:
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL, TenantAccountRole.DATASET_OPERATOR}
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL}
@staticmethod
def is_privileged_role(role: str) -> bool:
@@ -145,17 +131,12 @@ class TenantAccountRole(str, enum.Enum):
@staticmethod
def is_non_owner_role(role: str) -> bool:
return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL,
TenantAccountRole.DATASET_OPERATOR}
return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL}
@staticmethod
def is_editing_role(role: str) -> bool:
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
@staticmethod
def is_dataset_edit_role(role: str) -> bool:
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR,
TenantAccountRole.DATASET_OPERATOR}
class Tenant(db.Model):
__tablename__ = 'tenants'
@@ -191,7 +172,6 @@ class TenantAccountJoinRole(enum.Enum):
OWNER = 'owner'
ADMIN = 'admin'
NORMAL = 'normal'
DATASET_OPERATOR = 'dataset_operator'
class TenantAccountJoin(db.Model):

View File

@@ -663,18 +663,3 @@ class DatasetCollectionBinding(db.Model):
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class DatasetPermission(db.Model):
__tablename__ = 'dataset_permissions'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'),
db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'),
db.Index('idx_dataset_permissions_account_id', 'account_id')
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True)
dataset_id = db.Column(StringUUID, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

View File

@@ -369,28 +369,6 @@ class TenantService:
return updated_accounts
@staticmethod
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
"""Get dataset admin members"""
query = (
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
)
.filter(TenantAccountJoin.tenant_id == tenant.id)
.filter(TenantAccountJoin.role == 'dataset_operator')
)
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in query:
account.role = role
updated_accounts.append(account)
return updated_accounts
@staticmethod
def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool:
"""Check if user has any of the given roles for a tenant"""

View File

@@ -21,12 +21,11 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.account import Account, TenantAccountRole
from models.account import Account
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetCollectionBinding,
DatasetPermission,
DatasetProcessRule,
DatasetQuery,
Document,
@@ -57,38 +56,22 @@ class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id)
if user:
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id).all()
if dataset_permission:
dataset_ids = [dp.dataset_id for dp in dataset_permission]
query = query.filter(Dataset.id.in_(dataset_ids))
else:
query = query.filter(db.false())
else:
permission_filter = db.or_(
Dataset.created_by == user.id,
Dataset.permission == 'all_team_members',
Dataset.permission == 'partial_members',
Dataset.permission == 'only_me'
)
query = query.filter(permission_filter)
permission_filter = db.or_(Dataset.created_by == user.id,
Dataset.permission == 'all_team_members')
else:
permission_filter = Dataset.permission == 'all_team_members'
query = query.filter(permission_filter)
query = Dataset.query.filter(
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
.order_by(Dataset.created_at.desc())
if search:
query = query.filter(Dataset.name.ilike(f'%{search}%'))
query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
if target_ids:
query = query.filter(Dataset.id.in_(target_ids))
query = query.filter(db.and_(Dataset.id.in_(target_ids)))
else:
return [], 0
datasets = query.paginate(
page=page,
per_page=per_page,
@@ -96,12 +79,6 @@ class DatasetService:
error_out=False
)
# check datasets permission,
if user and user.current_role != TenantAccountRole.DATASET_OPERATOR:
datasets.items, datasets.total = DatasetService.filter_datasets_by_permission(
user, datasets
)
return datasets.items, datasets.total
@staticmethod
@@ -125,12 +102,9 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(
Dataset.id.in_(ids),
Dataset.tenant_id == tenant_id
).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
)
datasets = Dataset.query.filter(Dataset.id.in_(ids),
Dataset.tenant_id == tenant_id).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total
@staticmethod
@@ -138,8 +112,7 @@ class DatasetService:
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f'Dataset with name {name} already exists.'
)
f'Dataset with name {name} already exists.')
embedding_model = None
if indexing_technique == 'high_quality':
model_manager = ModelManager()
@@ -178,17 +151,13 @@ class DatasetService:
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(
f"The dataset in unavailable, due to: "
f"{ex.description}"
)
raise ValueError(f"The dataset in unavailable, due to: "
f"{ex.description}")
@staticmethod
def update_dataset(dataset_id, data, user):
data.pop('partial_member_list', None)
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
dataset = DatasetService.get_dataset(dataset_id)
DatasetService.check_dataset_permission(dataset, user)
@@ -221,13 +190,12 @@ class DatasetService:
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if data['embedding_model_provider'] != dataset.embedding_model_provider or \
data['embedding_model'] != dataset.embedding_model:
data['embedding_model'] != dataset.embedding_model:
action = 'update'
try:
model_manager = ModelManager()
@@ -247,8 +215,7 @@ class DatasetService:
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
@@ -292,41 +259,14 @@ class DatasetService:
def check_dataset_permission(dataset, user):
if dataset.tenant_id != user.current_tenant_id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}'
)
f'User {user.id} does not have permission to access dataset {dataset.id}')
raise NoPermissionError(
'You do not have permission to access this dataset.'
)
'You do not have permission to access this dataset.')
if dataset.permission == 'only_me' and dataset.created_by != user.id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}'
)
f'User {user.id} does not have permission to access dataset {dataset.id}')
raise NoPermissionError(
'You do not have permission to access this dataset.'
)
if dataset.permission == 'partial_members':
user_permission = DatasetPermission.query.filter_by(
dataset_id=dataset.id, account_id=user.id
).first()
if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id:
logging.debug(
f'User {user.id} does not have permission to access dataset {dataset.id}'
)
raise NoPermissionError(
'You do not have permission to access this dataset.'
)
@staticmethod
def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
if dataset.permission == 'only_me':
if dataset.created_by != user.id:
raise NoPermissionError('You do not have permission to access this dataset.')
elif dataset.permission == 'partial_members':
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
):
raise NoPermissionError('You do not have permission to access this dataset.')
'You do not have permission to access this dataset.')
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
@@ -342,22 +282,6 @@ class DatasetService:
return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
.order_by(db.desc(AppDatasetJoin.created_at)).all()
@staticmethod
def filter_datasets_by_permission(user, datasets):
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id).all()
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else set()
filtered_datasets = [
dataset for dataset in datasets if
(dataset.permission == 'all_team_members') or
(dataset.permission == 'only_me' and dataset.created_by == user.id) or
(dataset.id in permitted_dataset_ids)
]
filtered_count = len(filtered_datasets)
return filtered_datasets, filtered_count
class DocumentService:
DEFAULT_RULES = {
@@ -623,7 +547,6 @@ class DocumentService:
redis_client.setex(sync_indexing_cache_key, 600, 1)
sync_website_document_indexing_task.delay(dataset_id, document.id)
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
@@ -633,11 +556,9 @@ class DocumentService:
return 1
@staticmethod
def save_document_with_dataset_id(
dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'
):
def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
# check document limit
features = FeatureService.get_features(current_user.current_tenant_id)
@@ -667,7 +588,7 @@ class DocumentService:
if not dataset.indexing_technique:
if 'indexing_technique' not in document_data \
or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is required")
dataset.indexing_technique = document_data["indexing_technique"]
@@ -697,8 +618,7 @@ class DocumentService:
}
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
'retrieval_model'
) else default_retrieval_model
'retrieval_model') else default_retrieval_model
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -766,14 +686,12 @@ class DocumentService:
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch
)
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, file_name, batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -814,14 +732,12 @@ class DocumentService:
"notion_page_icon": page['page_icon'],
"type": page['type']
}
document = DocumentService.build_document(
dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch
)
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, page['page_name'], batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -843,14 +759,12 @@ class DocumentService:
'only_main_content': website_info.get('only_main_content', False),
'mode': 'crawl',
}
document = DocumentService.build_document(
dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, url, batch
)
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info, created_from, position,
account, url, batch)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
@@ -871,16 +785,13 @@ class DocumentService:
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size:
raise ValueError(
f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.'
)
f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
@staticmethod
def build_document(
dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str
):
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
document_language: str, data_source_info: dict, created_from: str, position: int,
account: Account,
name: str, batch: str):
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@@ -899,20 +810,16 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id
).count()
documents_count = Document.query.filter(Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id).count()
return documents_count
@staticmethod
def update_document_with_dataset_id(
dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'
):
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'):
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
if document.display_status != 'available':
@@ -1100,7 +1007,7 @@ class DocumentService:
DocumentService.process_rule_args_validate(args)
else:
if ('data_source' not in args and not args['data_source']) \
and ('process_rule' not in args and not args['process_rule']):
and ('process_rule' not in args and not args['process_rule']):
raise ValueError("Data source or Process rule is required")
else:
if args.get('data_source'):
@@ -1162,7 +1069,7 @@ class DocumentService:
raise ValueError("Process rule rules is invalid")
if 'pre_processing_rules' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['pre_processing_rules'] is None:
or args['process_rule']['rules']['pre_processing_rules'] is None:
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
@@ -1187,21 +1094,21 @@ class DocumentService:
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
if 'segmentation' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['segmentation'] is None:
or args['process_rule']['rules']['segmentation'] is None:
raise ValueError("Process rule segmentation is required")
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
raise ValueError("Process rule segmentation is invalid")
if 'separator' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['separator']:
or not args['process_rule']['rules']['segmentation']['separator']:
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
raise ValueError("Process rule segmentation separator is invalid")
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['max_tokens']:
or not args['process_rule']['rules']['segmentation']['max_tokens']:
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
@@ -1237,7 +1144,7 @@ class DocumentService:
raise ValueError("Process rule rules is invalid")
if 'pre_processing_rules' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['pre_processing_rules'] is None:
or args['process_rule']['rules']['pre_processing_rules'] is None:
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
@@ -1262,21 +1169,21 @@ class DocumentService:
args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())
if 'segmentation' not in args['process_rule']['rules'] \
or args['process_rule']['rules']['segmentation'] is None:
or args['process_rule']['rules']['segmentation'] is None:
raise ValueError("Process rule segmentation is required")
if not isinstance(args['process_rule']['rules']['segmentation'], dict):
raise ValueError("Process rule segmentation is invalid")
if 'separator' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['separator']:
or not args['process_rule']['rules']['segmentation']['separator']:
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
raise ValueError("Process rule segmentation separator is invalid")
if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
or not args['process_rule']['rules']['segmentation']['max_tokens']:
or not args['process_rule']['rules']['segmentation']['max_tokens']:
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
@@ -1530,16 +1437,12 @@ class SegmentService:
class DatasetCollectionBindingService:
@classmethod
def get_dataset_collection_binding(
cls, provider_name: str, model_name: str,
collection_type: str = 'dataset'
) -> DatasetCollectionBinding:
def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
collection_type: str = 'dataset') -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type
). \
filter(DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type). \
order_by(DatasetCollectionBinding.created_at). \
first()
@@ -1555,76 +1458,12 @@ class DatasetCollectionBindingService:
return dataset_collection_binding
@classmethod
def get_dataset_collection_binding_by_id_and_type(
cls, collection_binding_id: str,
collection_type: str = 'dataset'
) -> DatasetCollectionBinding:
def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
collection_type: str = 'dataset') -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(
DatasetCollectionBinding.id == collection_binding_id,
DatasetCollectionBinding.type == collection_type
). \
filter(DatasetCollectionBinding.id == collection_binding_id,
DatasetCollectionBinding.type == collection_type). \
order_by(DatasetCollectionBinding.created_at). \
first()
return dataset_collection_binding
class DatasetPermissionService:
@classmethod
def get_dataset_partial_member_list(cls, dataset_id):
user_list_query = db.session.query(
DatasetPermission.account_id,
).filter(
DatasetPermission.dataset_id == dataset_id
).all()
user_list = []
for user in user_list_query:
user_list.append(user.account_id)
return user_list
@classmethod
def update_partial_member_list(cls, dataset_id, user_list):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
permissions = []
for user in user_list:
permission = DatasetPermission(
dataset_id=dataset_id,
account_id=user['user_id'],
)
permissions.append(permission)
db.session.add_all(permissions)
db.session.commit()
except Exception as e:
db.session.rollback()
raise e
@classmethod
def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list):
if not user.is_dataset_editor:
raise NoPermissionError('User does not have permission to edit this dataset.')
if user.is_dataset_operator and dataset.permission != requested_permission:
raise NoPermissionError('Dataset operators cannot change the dataset permissions.')
if user.is_dataset_operator and requested_permission == 'partial_members':
if not requested_partial_member_list:
raise ValueError('Partial member list is required when setting to partial members.')
local_member_list = cls.get_dataset_partial_member_list(dataset.id)
request_member_list = [user['user_id'] for user in requested_partial_member_list]
if set(local_member_list) != set(request_member_list):
raise ValueError('Dataset operators cannot change the dataset permissions.')
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
db.session.commit()
except Exception as e:
db.session.rollback()
raise e
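After the revert, dataset visibility reduces to "created by me" or "shared with all team members". A rough sketch of that simplified filter, under the assumption that Dataset, db, and user stand in for the real SQLAlchemy model, session helpers, and account object (the function name and parameters are illustrative only):

def visible_datasets_query(db, Dataset, tenant_id, user=None):
    # Base scope: datasets in the current workspace/tenant.
    query = Dataset.query.filter(Dataset.tenant_id == tenant_id)
    if user:
        # Own datasets plus anything shared with the whole team.
        query = query.filter(db.or_(
            Dataset.created_by == user.id,
            Dataset.permission == 'all_team_members',
        ))
    else:
        # Without a user context, only team-wide datasets are visible.
        query = query.filter(Dataset.permission == 'all_team_members')
    return query.order_by(Dataset.created_at.desc())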

View File

@@ -30,7 +30,6 @@ class FeatureModel(BaseModel):
docs_processing: str = 'standard'
can_replace_logo: bool = False
model_load_balancing_enabled: bool = False
dataset_operator_enabled: bool = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@@ -69,7 +68,6 @@ class FeatureService:
def _fulfill_params_from_env(cls, features: FeatureModel):
features.can_replace_logo = current_app.config['CAN_REPLACE_LOGO']
features.model_load_balancing_enabled = current_app.config['MODEL_LB_ENABLED']
features.dataset_operator_enabled = current_app.config['DATASET_OPERATOR_ENABLED']
@classmethod
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):