Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong committed Jan 8, 2025
1 parent 67228c9 commit a572d1d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
3 changes: 2 additions & 1 deletion api/controllers/console/datasets/datasets_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def post(self, dataset_id):
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")

parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
Expand Down
20 changes: 13 additions & 7 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,13 +792,19 @@ def save_document_with_dataset_id(
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
Expand All @@ -810,7 +816,7 @@ def save_document_with_dataset_id(
"score_threshold_enabled": False,
}

dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() if knowledge_config.retrieval_model else default_retrieval_model # type: ignore

documents = []
if knowledge_config.original_document_id:
Expand Down
5 changes: 4 additions & 1 deletion api/tasks/deal_dataset_vector_index_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
Expand Down Expand Up @@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

end_at = time.perf_counter()
logging.info(
Expand Down

0 comments on commit a572d1d

Please sign in to comment.