fix #12453 #12482 #12495

Merged Jan 8, 2025 (2 commits)
api/controllers/console/datasets/datasets_document.py (2 additions, 1 deletion)

@@ -257,7 +257,8 @@ def post(self, dataset_id):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")

parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
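These two optional parameters let a console document-create request pin the embedding model explicitly instead of silently inheriting the tenant default. A minimal sketch of the relevant request body (the model and provider values are illustrative examples, not defaults; only the last two fields are new):

# Illustrative JSON body for the document-create endpoint patched above.
payload = {
    "doc_form": "text_model",
    "doc_language": "English",
    "retrieval_model": {"score_threshold_enabled": False},
    "embedding_model": "text-embedding-3-large",   # new, optional
    "embedding_model_provider": "openai",          # new, optional
}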
api/services/dataset_service.py (17 additions, 7 deletions)

@@ -792,13 +792,19 @@ def save_document_with_dataset_id(
             dataset.indexing_technique = knowledge_config.indexing_technique
             if knowledge_config.indexing_technique == "high_quality":
                 model_manager = ModelManager()
-                embedding_model = model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
-                )
-                dataset.embedding_model = embedding_model.model
-                dataset.embedding_model_provider = embedding_model.provider
+                if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
+                    dataset_embedding_model = knowledge_config.embedding_model
+                    dataset_embedding_model_provider = knowledge_config.embedding_model_provider
+                else:
+                    embedding_model = model_manager.get_default_model_instance(
+                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+                    )
+                    dataset_embedding_model = embedding_model.model
+                    dataset_embedding_model_provider = embedding_model.provider
+                dataset.embedding_model = dataset_embedding_model
+                dataset.embedding_model_provider = dataset_embedding_model_provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_model.provider, embedding_model.model
+                    dataset_embedding_model_provider, dataset_embedding_model
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
                 if not dataset.retrieval_model:
@@ -810,7 +816,11 @@ def save_document_with_dataset_id(
                         "score_threshold_enabled": False,
                     }

-                dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore
+                dataset.retrieval_model = (
+                    knowledge_config.retrieval_model.model_dump()
+                    if knowledge_config.retrieval_model
+                    else default_retrieval_model
+                )  # type: ignore

         documents = []
         if knowledge_config.original_document_id:
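Two behavioral fixes land in this file: document creation now honors a caller-supplied embedding model, falling back to the tenant default only when none is given, and a missing retrieval_model no longer crashes the save path, since the old `model_dump() or default` form dereferenced None before the `or` could take effect. A standalone sketch of both, using simplified stand-ins for Dify's config models (class and field names here are not the real definitions):

from typing import Optional
from pydantic import BaseModel

# Simplified stand-ins for illustration only.
class RetrievalConfig(BaseModel):
    search_method: str = "semantic_search"
    score_threshold_enabled: bool = False

class KnowledgeConfig(BaseModel):
    embedding_model: Optional[str] = None
    embedding_model_provider: Optional[str] = None
    retrieval_model: Optional[RetrievalConfig] = None

def resolve_embedding(cfg: KnowledgeConfig, default_model: str, default_provider: str):
    # Prefer the caller's explicit pair; otherwise fall back to the tenant default.
    if cfg.embedding_model and cfg.embedding_model_provider:
        return cfg.embedding_model, cfg.embedding_model_provider
    return default_model, default_provider

def resolve_retrieval(cfg: KnowledgeConfig, default: dict) -> dict:
    # Old form: cfg.retrieval_model.model_dump() or default
    # raises AttributeError when retrieval_model is None.
    return cfg.retrieval_model.model_dump() if cfg.retrieval_model else default

cfg = KnowledgeConfig()  # nothing supplied by the caller
print(resolve_embedding(cfg, "default-model", "default-provider"))
print(resolve_retrieval(cfg, {"score_threshold_enabled": False}))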
api/tasks/deal_dataset_vector_index_task.py (4 additions, 1 deletion)

@@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

     if not dataset:
         raise Exception("Dataset not found")
-    index_type = dataset.doc_form
+    index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
     index_processor = IndexProcessorFactory(index_type).init_index_processor()
     if action == "remove":
         index_processor.clean(dataset, None, with_keywords=False)
@@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
                         {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                     )
                     db.session.commit()
+        else:
+            # clean collection
+            index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

     end_at = time.perf_counter()
     logging.info(
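The first hunk guards against datasets whose doc_form was never set (a dataset created before any document is indexed), defaulting to paragraph indexing; the second, assuming the added else pairs with the task's check for existing documents, cleans the vector collection even when a dataset has no documents left. A toy illustration of the doc_form guard, with hypothetical stand-ins for the real IndexType and IndexProcessorFactory in Dify's index-processor module:

# Toy stand-ins; the real classes live in the Dify codebase.
class IndexType:
    PARAGRAPH_INDEX = "text_model"

class IndexProcessorFactory:
    def __init__(self, index_type):
        if not index_type:
            # the failure mode the old `index_type = dataset.doc_form` risked
            raise ValueError("index type must not be None")
        self.index_type = index_type

    def init_index_processor(self):
        return f"<{self.index_type} processor>"

doc_form = None  # dataset created with no documents yet
index_type = doc_form or IndexType.PARAGRAPH_INDEX
print(IndexProcessorFactory(index_type).init_index_processor())  # <text_model processor>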