Fix: style checks and unittests #12603

Merged — merged 4 commits on Jan 10, 2025
Changes from all commits
9 changes: 3 additions & 6 deletions api/core/helper/ssrf_proxy.py
@@ -60,20 +60,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
             if response.status_code not in STATUS_FORCELIST:
                 return response
             else:
-                logging.warning(
-                    f"Received status code {response.status_code} for URL {url} which is in the force list")
+                logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
 
         except httpx.RequestError as e:
-            logging.warning(f"Request to URL {url} failed on attempt {
-                retries + 1}: {e}")
+            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
             if max_retries == 0:
                 raise
 
         retries += 1
         if retries <= max_retries:
             time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
-    raise MaxRetriesExceededError(
-        f"Reached maximum retries ({max_retries}) for URL {url}")
+    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
 
 
 def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
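The hunk above only reflows long log calls; the retry loop around them is the interesting part. A minimal standalone sketch of that backoff pattern — the constant values and the send callable are assumptions, only the names BACKOFF_FACTOR/STATUS_FORCELIST and the 2 ** (retries - 1) schedule come from the diff:

    import time

    BACKOFF_FACTOR = 0.5                     # assumed value; the diff only uses the name
    STATUS_FORCELIST = {429, 500, 502, 503}  # assumed retryable status codes

    def request_with_backoff(send, max_retries=3):
        # send() returns an HTTP status code; retry while it is in the force list,
        # sleeping BACKOFF_FACTOR * 2**(n-1) seconds between attempts (0.5s, 1s, 2s, ...)
        retries = 0
        while retries <= max_retries:
            status = send()
            if status not in STATUS_FORCELIST:
                return status
            retries += 1
            if retries <= max_retries:
                time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
        raise RuntimeError(f"reached maximum retries ({max_retries})")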
44 changes: 15 additions & 29 deletions api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -17,8 +17,7 @@
 from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO,
-                    format="%(asctime)s - %(levelname)s - %(message)s")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)
 
 ROUTING_FIELD = "routing_field"
@@ -135,8 +134,7 @@ def delete_by_ids(self, ids: list[str]) -> None:
                 self._client.delete(index=self._collection_name, id=id, params=params)
                 self.refresh()
             else:
-                logger.warning(
-                    f"DELETE BY ID: ID {id} does not exist in the index.")
+                logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
 
     def delete(self) -> None:
         if self._using_ugc:
@@ -147,8 +145,7 @@ def delete(self) -> None:
             self.refresh()
         else:
             if self._client.indices.exists(index=self._collection_name):
-                self._client.indices.delete(
-                    index=self._collection_name, params={"timeout": 60})
+                self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                 logger.info("Delete index success")
             else:
                 logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
@@ -171,14 +168,13 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
             raise ValueError("All elements in query_vector should be floats")
 
         top_k = kwargs.get("top_k", 10)
-        query = default_vector_search_query(
-            query_vector=query_vector, k=top_k, **kwargs)
+        query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
         try:
             params = {}
             if self._using_ugc:
                 params["routing"] = self._routing
             response = self._client.search(index=self._collection_name, body=query, params=params)
-        except Exception as e:
+        except Exception:
             logger.exception(f"Error executing vector search, query: {query}")
             raise
 
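Two fixes are folded into this hunk: the reflowed call, and the change from except Exception as e: to a bare except Exception:, which works because logger.exception already records the active traceback, so the bound name was unused. A small self-contained illustration (the risky helper is hypothetical):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def risky():
        raise ValueError("boom")

    def search():
        try:
            risky()
        except Exception:
            # logger.exception logs at ERROR level and appends the full traceback,
            # so binding the exception to a name just for logging is unnecessary
            logger.exception("Error executing vector search")
            raise  # callers still see the failure, as in the diff

    try:
        search()
    except ValueError:
        pass  # demonstration only: the error was already logged with its traceback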
@@ -224,8 +220,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
             routing=routing,
             routing_field=self._routing_field,
         )
-        response = self._client.search(
-            index=self._collection_name, body=full_text_query)
+        response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
         for hit in response["hits"]["hits"]:
             docs.append(
@@ -243,8 +238,7 @@ def create_collection(self, dimension: int, **kwargs):
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
-                logger.info(
-                    f"Collection {self._collection_name} already exists.")
+                logger.info(f"Collection {self._collection_name} already exists.")
                 return
             if self._client.indices.exists(index=self._collection_name):
                 logger.info(f"{self._collection_name.lower()} already exists.")
@@ -264,13 +258,10 @@ def create_collection(self, dimension: int, **kwargs):
                 hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
                 ivfpq_m = kwargs.pop("ivfpq_m", dimension)
                 nlist = kwargs.pop("nlist", 1000)
-                centroids_use_hnsw = kwargs.pop(
-                    "centroids_use_hnsw", True if nlist >= 5000 else False)
+                centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
                 centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
-                centroids_hnsw_ef_construct = kwargs.pop(
-                    "centroids_hnsw_ef_construct", 500)
-                centroids_hnsw_ef_search = kwargs.pop(
-                    "centroids_hnsw_ef_search", 100)
+                centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
+                centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
             mapping = default_text_mapping(
                 dimension,
                 method_name,
@@ -290,8 +281,7 @@ def create_collection(self, dimension: int, **kwargs):
                 using_ugc=self._using_ugc,
                 **kwargs,
             )
-            self._client.indices.create(
-                index=self._collection_name.lower(), body=mapping)
+            self._client.indices.create(index=self._collection_name.lower(), body=mapping)
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
             # logger.info(f"create index success: {self._collection_name}")
 
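For context, create_collection combines a Redis distributed lock with a cached flag so that concurrent workers create an index exactly once. A minimal sketch of the pattern, assuming a redis-py client and a create_index callable (the lock-name format is hypothetical; the cache-key format and one-hour expiry mirror the diff):

    import redis

    redis_client = redis.Redis()  # assumed local connection; the real client is injected

    def create_collection_once(collection_name: str, create_index) -> None:
        lock_name = f"vector_indexing_lock_{collection_name}"  # hypothetical key format
        cache_key = f"vector_indexing_{collection_name}"       # as in the diff
        with redis_client.lock(lock_name, timeout=20):
            if redis_client.get(cache_key):
                return  # another worker finished while we waited for the lock
            create_index()
            redis_client.set(cache_key, 1, ex=3600)            # remember for an hour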
@@ -396,8 +386,7 @@ def default_text_search_query(
     # build complex search_query when either of must/must_not/should/filter is specified
     if must:
         if not isinstance(must, list):
-            raise RuntimeError(
-                f"unexpected [must] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
         if query_clause not in must:
             must.append(query_clause)
     else:
@@ -407,22 +396,19 @@ def default_text_search_query(
 
     if must_not:
         if not isinstance(must_not, list):
-            raise RuntimeError(
-                f"unexpected [must_not] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
         boolean_query["must_not"] = must_not
 
     if should:
         if not isinstance(should, list):
-            raise RuntimeError(
-                f"unexpected [should] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
         boolean_query["should"] = should
         if minimum_should_match != 0:
             boolean_query["minimum_should_match"] = minimum_should_match
 
     if filters:
         if not isinstance(filters, list):
-            raise RuntimeError(
-                f"unexpected [filter] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
         boolean_query["filter"] = filters
 
     search_query = {"size": k, "query": {"bool": boolean_query}}
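The clauses validated above slot into a standard OpenSearch-style bool query. A trimmed sketch of the shape default_text_search_query assembles — the field names and match clause are illustrative; the final {"size": k, "query": {"bool": ...}} envelope is taken from the diff:

    def build_bool_query(query_text: str, k: int = 4, must=None, must_not=None,
                         should=None, filters=None) -> dict:
        # the text-match clause always joins the `must` list; the rest are optional
        boolean_query = {"must": (must or []) + [{"match": {"content": query_text}}]}
        if must_not:
            boolean_query["must_not"] = must_not
        if should:
            boolean_query["should"] = should
        if filters:
            boolean_query["filter"] = filters
        return {"size": k, "query": {"bool": boolean_query}}

    # e.g. restricting results to a single document:
    q = build_bool_query("error handling", filters=[{"term": {"document_id": "doc-1"}}])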
api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -44,22 +44,19 @@ def _run(self):
         variable_pool = self.graph_runtime_state.variable_pool
 
         # extract variables
-        variable = variable_pool.get(
-            node_data.query_variable_selector) if node_data.query_variable_selector else None
+        variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
         query = variable.value if variable else None
         variables = {"query": query}
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(
-            node_data.model)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
         memory = self._fetch_memory(
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
         # fetch instruction
         node_data.instruction = node_data.instruction or ""
-        node_data.instruction = variable_pool.convert_template(
-            node_data.instruction).text
+        node_data.instruction = variable_pool.convert_template(node_data.instruction).text
 
         files = (
             self._fetch_files(
@@ -181,15 +178,12 @@ def _extract_variable_selector_to_variable_mapping(
         variable_mapping = {"query": node_data.query_variable_selector}
         variable_selectors = []
         if node_data.instruction:
-            variable_template_parser = VariableTemplateParser(
-                template=node_data.instruction)
-            variable_selectors.extend(
-                variable_template_parser.extract_variable_selectors())
+            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
-        variable_mapping = {node_id + "." + key: value for key,
-                            value in variable_mapping.items()}
+        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
 
         return variable_mapping
 
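The last reflowed line namespaces every selector key under the owning node id. The same comprehension in isolation, with illustrative selector values:

    node_id = "question_classifier_1"
    variable_mapping = {
        "query": ["sys", "query"],   # illustrative selectors
        "topic": ["start", "topic"],
    }
    # prefix each key with the node id, exactly as the diff's comprehension does
    variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
    assert variable_mapping == {
        "question_classifier_1.query": ["sys", "query"],
        "question_classifier_1.topic": ["start", "topic"],
    }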
@@ -210,8 +204,7 @@ def _calculate_rest_token(
         context: Optional[str],
     ) -> int:
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(
-            node_data, query, None, 2000)
+        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs={},
@@ -224,15 +217,13 @@ def _calculate_rest_token(
         )
         rest_tokens = 2000
 
-        model_context_tokens = model_config.model_schema.model_properties.get(
-            ModelPropertyKey.CONTEXT_SIZE)
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
         if model_context_tokens:
             model_instance = ModelInstance(
                 provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
             )
 
-            curr_message_tokens = model_instance.get_llm_num_tokens(
-                prompt_messages)
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
 
             max_tokens = 0
             for parameter_rule in model_config.model_schema.parameter_rules:
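_calculate_rest_token boils down to subtracting the rendered prompt and the configured completion budget from the model's context window, with 2000 as a fallback. A hedged standalone form of that arithmetic — the subtraction order and clamping are an assumption based on the variables visible in the hunk, not a copy of the method:

    def calculate_rest_tokens(model_context_tokens, curr_message_tokens, max_tokens,
                              fallback: int = 2000) -> int:
        # without a known context size, fall back to a fixed budget (2000 in the diff)
        if not model_context_tokens:
            return fallback
        # tokens left for the completion once the prompt and max_tokens are reserved
        return max(model_context_tokens - max_tokens - curr_message_tokens, 0)

    # e.g. an 8192-token context, a 1200-token prompt, and a 512-token completion cap:
    assert calculate_rest_tokens(8192, 1200, 512) == 6480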
@@ -273,8 +264,7 @@ def _get_prompt_template(
         prompt_messages: list[LLMNodeChatModelMessage] = []
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = LLMNodeChatModelMessage(
-                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(
-                    histories=memory_str)
+                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
             )
             prompt_messages.append(system_prompt_messages)
             user_prompt_message_1 = LLMNodeChatModelMessage(
@@ -315,5 +305,4 @@ def _get_prompt_template(
             )
 
         else:
-            raise InvalidModelTypeError(
-                f"Model mode {model_mode} not support.")
+            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
6 changes: 6 additions & 0 deletions api/pytest.ini
@@ -7,6 +7,12 @@ env =
     CODE_EXECUTION_API_KEY = dify-sandbox
     CODE_EXECUTION_ENDPOINT = http://127.0.0.1:8194
     CODE_MAX_STRING_LENGTH = 80000
+    PLUGIN_API_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
+    PLUGIN_DAEMON_URL=http://127.0.0.1:5002
+    PLUGIN_MAX_PACKAGE_SIZE=15728640
+    INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
+    MARKETPLACE_ENABLED=true
+    MARKETPLACE_API_URL=https://marketplace.dify.ai
     FIRECRAWL_API_KEY = fc-
     FIREWORKS_API_KEY = fw_aaaaaaaaaaaaaaaaaaaa
     GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
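The six new entries configure the plugin daemon and marketplace for the test run; the keys are placeholder credentials, not live secrets. Entries under env = in pytest.ini are consumed by the pytest-env plugin, which exports them into os.environ before tests run, so a test can read them directly — a minimal sketch:

    import os

    def test_marketplace_config_is_present():
        # pytest-env sets these from the `env =` block in api/pytest.ini
        assert os.environ["MARKETPLACE_ENABLED"] == "true"
        assert os.environ["MARKETPLACE_API_URL"] == "https://marketplace.dify.ai"
        assert int(os.environ["PLUGIN_MAX_PACKAGE_SIZE"]) == 15728640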
api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
@@ -68,8 +68,7 @@ def test_executor_with_json_body_and_object_variable():
         system_variables={},
         user_inputs={},
     )
-    variable_pool.add(["pre_node_id", "object"], {
-        "name": "John Doe", "age": 30, "email": "john@example.com"})
+    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
 
     # Prepare the node data
    node_data = HttpRequestNodeData(
@@ -124,8 +123,7 @@ def test_executor_with_json_body_and_nested_object_variable():
         system_variables={},
         user_inputs={},
     )
-    variable_pool.add(["pre_node_id", "object"], {
-        "name": "John Doe", "age": 30, "email": "john@example.com"})
+    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
 
     # Prepare the node data
     node_data = HttpRequestNodeData(
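In both tests the object now lands in the pool as a single literal; a selector such as ["pre_node_id", "object", "name"] later digs the nested value out for the JSON body template. A toy model of that lookup — this is not the real VariablePool API, just the resolution idea:

    pool: dict = {}

    def add(selector: list[str], value) -> None:
        pool[tuple(selector)] = value

    def get(selector: list[str]):
        value = pool[tuple(selector[:2])]  # first two segments address the variable
        for key in selector[2:]:           # remaining segments walk into the object
            value = value[key]
        return value

    add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
    assert get(["pre_node_id", "object", "name"]) == "John Doe"
    assert get(["pre_node_id", "object", "age"]) == 30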
api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -18,14 +18,6 @@
 from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 
-def test_plain_text_to_dict():
-    assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
-    assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
-    assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
-    assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {
-        "aa": "bb", "cc": "dd"}
-
-
 def test_http_request_node_binary_file(monkeypatch):
     data = HttpRequestNodeData(
         title="test",
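The deleted test is the only specification of _plain_text_to_dict visible here: each non-empty line splits on the first colon, keys and values are whitespace-stripped, and a missing value becomes an empty string. A sketch that satisfies exactly those removed assertions (the real helper may differ):

    def plain_text_to_dict(text: str) -> dict[str, str]:
        # "aa:bb\n cc:dd" -> {"aa": "bb", "cc": "dd"};  "aa\n cc:" -> {"aa": "", "cc": ""}
        result: dict[str, str] = {}
        for line in text.splitlines():
            if not line.strip():
                continue  # skip blank lines, per the "\n\n" assertion
            key, _, value = line.partition(":")
            result[key.strip()] = value.strip()
        return result

    assert plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
    assert plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}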
@@ -191,8 +183,7 @@ def test_http_request_node_form_with_file(monkeypatch):
 
     def attr_checker(*args, **kwargs):
         assert kwargs["data"] == {"name": "test"}
-        assert kwargs["files"] == {
-            "file": (None, b"test", "application/octet-stream")}
+        assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
         return httpx.Response(200, content=b"")
 
     monkeypatch.setattr(
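The attr_checker pattern asserts on the kwargs the node hands to httpx instead of performing real network I/O. A condensed, self-contained version of the same monkeypatch idea (the Client class stands in for the real transport that monkeypatch.setattr targets in the test):

    import httpx

    class Client:
        def request(self, method: str, url: str, **kwargs) -> httpx.Response:
            raise RuntimeError("network disabled in tests")

    def post_form(client: Client) -> httpx.Response:
        # code under test: posts a form with a file part
        return client.request(
            "POST",
            "http://example.org/upload",
            data={"name": "test"},
            files={"file": (None, b"test", "application/octet-stream")},
        )

    def test_form_with_file(monkeypatch):
        def attr_checker(self, method, url, **kwargs):
            assert kwargs["data"] == {"name": "test"}
            assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
            return httpx.Response(200, content=b"")

        monkeypatch.setattr(Client, "request", attr_checker)  # swap out the network call
        assert post_form(Client()).status_code == 200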