
Commit c03adcb

Fix: style checks and unittests (#12603)
1 parent 04dade2 commit c03adcb

6 files changed: +38 -71 lines

api/core/helper/ssrf_proxy.py (+3 -6)

@@ -60,20 +60,17 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
            if response.status_code not in STATUS_FORCELIST:
                return response
            else:
-                logging.warning(
-                    f"Received status code {response.status_code} for URL {url} which is in the force list")
+                logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")

        except httpx.RequestError as e:
-            logging.warning(f"Request to URL {url} failed on attempt {
-                retries + 1}: {e}")
+            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
            if max_retries == 0:
                raise

        retries += 1
        if retries <= max_retries:
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
-    raise MaxRetriesExceededError(
-        f"Reached maximum retries ({max_retries}) for URL {url}")
+    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")


def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
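Note for readers skimming the diff: this hunk only reflows the logging and raise calls; the underlying behavior is a capped exponential-backoff retry loop. Below is a minimal standalone sketch of that pattern, with BACKOFF_FACTOR, STATUS_FORCELIST and MaxRetriesExceededError stubbed in locally (the real module defines them elsewhere; the values and the helper name here are assumptions, not the module's actual code).

import logging
import time

import httpx

BACKOFF_FACTOR = 0.5  # assumed base delay in seconds; doubled after each attempt
STATUS_FORCELIST = {429, 500, 502, 503, 504}  # assumed retryable status codes


class MaxRetriesExceededError(Exception):
    """Raised once every retry attempt for a URL has failed."""


def fetch_with_retries(url: str, max_retries: int = 3) -> httpx.Response:
    retries = 0
    while retries <= max_retries:
        try:
            response = httpx.get(url)
            if response.status_code not in STATUS_FORCELIST:
                return response
            logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
        except httpx.RequestError as e:
            logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
            if max_retries == 0:
                raise
        retries += 1
        if retries <= max_retries:
            # exponential backoff: 0.5s, 1s, 2s, ... between attempts
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")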

api/core/rag/datasource/vdb/lindorm/lindorm_vector.py (+15 -29)

@@ -17,8 +17,7 @@
from models.dataset import Dataset

logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO,
-                    format="%(asctime)s - %(levelname)s - %(message)s")
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)

ROUTING_FIELD = "routing_field"

@@ -135,8 +134,7 @@ def delete_by_ids(self, ids: list[str]) -> None:
                self._client.delete(index=self._collection_name, id=id, params=params)
                self.refresh()
            else:
-                logger.warning(
-                    f"DELETE BY ID: ID {id} does not exist in the index.")
+                logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")

    def delete(self) -> None:
        if self._using_ugc:

@@ -147,8 +145,7 @@ def delete(self) -> None:
            self.refresh()
        else:
            if self._client.indices.exists(index=self._collection_name):
-                self._client.indices.delete(
-                    index=self._collection_name, params={"timeout": 60})
+                self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                logger.info("Delete index success")
            else:
                logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")

@@ -171,14 +168,13 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
            raise ValueError("All elements in query_vector should be floats")

        top_k = kwargs.get("top_k", 10)
-        query = default_vector_search_query(
-            query_vector=query_vector, k=top_k, **kwargs)
+        query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
        try:
            params = {}
            if self._using_ugc:
                params["routing"] = self._routing
            response = self._client.search(index=self._collection_name, body=query, params=params)
-        except Exception as e:
+        except Exception:
            logger.exception(f"Error executing vector search, query: {query}")
            raise

@@ -224,8 +220,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
            routing=routing,
            routing_field=self._routing_field,
        )
-        response = self._client.search(
-            index=self._collection_name, body=full_text_query)
+        response = self._client.search(index=self._collection_name, body=full_text_query)
        docs = []
        for hit in response["hits"]["hits"]:
            docs.append(

@@ -243,8 +238,7 @@ def create_collection(self, dimension: int, **kwargs):
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
-                logger.info(
-                    f"Collection {self._collection_name} already exists.")
+                logger.info(f"Collection {self._collection_name} already exists.")
                return
            if self._client.indices.exists(index=self._collection_name):
                logger.info(f"{self._collection_name.lower()} already exists.")

@@ -264,13 +258,10 @@ def create_collection(self, dimension: int, **kwargs):
            hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
            ivfpq_m = kwargs.pop("ivfpq_m", dimension)
            nlist = kwargs.pop("nlist", 1000)
-            centroids_use_hnsw = kwargs.pop(
-                "centroids_use_hnsw", True if nlist >= 5000 else False)
+            centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
            centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
-            centroids_hnsw_ef_construct = kwargs.pop(
-                "centroids_hnsw_ef_construct", 500)
-            centroids_hnsw_ef_search = kwargs.pop(
-                "centroids_hnsw_ef_search", 100)
+            centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
+            centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
            mapping = default_text_mapping(
                dimension,
                method_name,

@@ -290,8 +281,7 @@ def create_collection(self, dimension: int, **kwargs):
                using_ugc=self._using_ugc,
                **kwargs,
            )
-            self._client.indices.create(
-                index=self._collection_name.lower(), body=mapping)
+            self._client.indices.create(index=self._collection_name.lower(), body=mapping)
            redis_client.set(collection_exist_cache_key, 1, ex=3600)
            # logger.info(f"create index success: {self._collection_name}")

@@ -396,8 +386,7 @@ def default_text_search_query(
    # build complex search_query when either of must/must_not/should/filter is specified
    if must:
        if not isinstance(must, list):
-            raise RuntimeError(
-                f"unexpected [must] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
        if query_clause not in must:
            must.append(query_clause)
    else:

@@ -407,22 +396,19 @@ def default_text_search_query(
    if must_not:
        if not isinstance(must_not, list):
-            raise RuntimeError(
-                f"unexpected [must_not] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
        boolean_query["must_not"] = must_not

    if should:
        if not isinstance(should, list):
-            raise RuntimeError(
-                f"unexpected [should] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
        boolean_query["should"] = should
        if minimum_should_match != 0:
            boolean_query["minimum_should_match"] = minimum_should_match

    if filters:
        if not isinstance(filters, list):
-            raise RuntimeError(
-                f"unexpected [filter] clause with {type(filters)}")
+            raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
        boolean_query["filter"] = filters

    search_query = {"size": k, "query": {"bool": boolean_query}}

api/core/workflow/nodes/question_classifier/question_classifier_node.py (+11 -22)

@@ -44,22 +44,19 @@ def _run(self):
        variable_pool = self.graph_runtime_state.variable_pool

        # extract variables
-        variable = variable_pool.get(
-            node_data.query_variable_selector) if node_data.query_variable_selector else None
+        variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
        query = variable.value if variable else None
        variables = {"query": query}
        # fetch model config
-        model_instance, model_config = self._fetch_model_config(
-            node_data.model)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
        # fetch memory
        memory = self._fetch_memory(
            node_data_memory=node_data.memory,
            model_instance=model_instance,
        )
        # fetch instruction
        node_data.instruction = node_data.instruction or ""
-        node_data.instruction = variable_pool.convert_template(
-            node_data.instruction).text
+        node_data.instruction = variable_pool.convert_template(node_data.instruction).text

        files = (
            self._fetch_files(

@@ -181,15 +178,12 @@ def _extract_variable_selector_to_variable_mapping(
        variable_mapping = {"query": node_data.query_variable_selector}
        variable_selectors = []
        if node_data.instruction:
-            variable_template_parser = VariableTemplateParser(
-                template=node_data.instruction)
-            variable_selectors.extend(
-                variable_template_parser.extract_variable_selectors())
+            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

-        variable_mapping = {node_id + "." + key: value for key,
-                            value in variable_mapping.items()}
+        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

        return variable_mapping

@@ -210,8 +204,7 @@ def _calculate_rest_token(
        context: Optional[str],
    ) -> int:
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(
-            node_data, query, None, 2000)
+        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},

@@ -224,15 +217,13 @@ def _calculate_rest_token(
        )
        rest_tokens = 2000

-        model_context_tokens = model_config.model_schema.model_properties.get(
-            ModelPropertyKey.CONTEXT_SIZE)
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

-            curr_message_tokens = model_instance.get_llm_num_tokens(
-                prompt_messages)
+            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:

@@ -273,8 +264,7 @@ def _get_prompt_template(
        prompt_messages: list[LLMNodeChatModelMessage] = []
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = LLMNodeChatModelMessage(
-                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(
-                    histories=memory_str)
+                role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
            )
            prompt_messages.append(system_prompt_messages)
            user_prompt_message_1 = LLMNodeChatModelMessage(

@@ -315,5 +305,4 @@ def _get_prompt_template(
            )

        else:
-            raise InvalidModelTypeError(
-                f"Model mode {model_mode} not support.")
+            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")
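One reflowed line worth calling out is the dict comprehension in _extract_variable_selector_to_variable_mapping, which namespaces every mapping key with the node id. A tiny illustration with made-up values:

node_id = "classifier_1"  # hypothetical node id
variable_mapping = {
    "query": ["sys", "query"],
    "topic": ["start", "topic"],
}

# Same reshaping as in the diff: prefix each key with "<node_id>."
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

print(variable_mapping)
# {'classifier_1.query': ['sys', 'query'], 'classifier_1.topic': ['start', 'topic']}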

api/pytest.ini (+6 -0)

@@ -7,6 +7,12 @@ env =
    CODE_EXECUTION_API_KEY = dify-sandbox
    CODE_EXECUTION_ENDPOINT = http://127.0.0.1:8194
    CODE_MAX_STRING_LENGTH = 80000
+    PLUGIN_API_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
+    PLUGIN_DAEMON_URL=http://127.0.0.1:5002
+    PLUGIN_MAX_PACKAGE_SIZE=15728640
+    INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
+    MARKETPLACE_ENABLED=true
+    MARKETPLACE_API_URL=https://marketplace.dify.ai
    FIRECRAWL_API_KEY = fc-
    FIREWORKS_API_KEY = fw_aaaaaaaaaaaaaaaaaaaa
    GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
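The six new pytest.ini entries feed plugin and marketplace settings into the test environment; an `env =` block like this one is the kind consumed by the pytest-env plugin, which exports each entry as a process environment variable before tests run. A hedged example of a test reading them (the test name is made up for illustration):

import os


def test_plugin_settings_are_exported():
    # Values mirror the env = block added in api/pytest.ini above.
    assert os.environ["PLUGIN_DAEMON_URL"] == "http://127.0.0.1:5002"
    assert os.environ["MARKETPLACE_API_URL"] == "https://marketplace.dify.ai"
    assert os.environ["MARKETPLACE_ENABLED"] == "true"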

api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py (+2 -4)

@@ -68,8 +68,7 @@ def test_executor_with_json_body_and_object_variable():
        system_variables={},
        user_inputs={},
    )
-    variable_pool.add(["pre_node_id", "object"], {
-        "name": "John Doe", "age": 30, "email": "john@example.com"})
+    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})

    # Prepare the node data
    node_data = HttpRequestNodeData(

@@ -124,8 +123,7 @@ def test_executor_with_json_body_and_nested_object_variable():
        system_variables={},
        user_inputs={},
    )
-    variable_pool.add(["pre_node_id", "object"], {
-        "name": "John Doe", "age": 30, "email": "john@example.com"})
+    variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})

    # Prepare the node data
    node_data = HttpRequestNodeData(

api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py (+1 -10)

@@ -18,14 +18,6 @@
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


-def test_plain_text_to_dict():
-    assert _plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
-    assert _plain_text_to_dict("aa:bb\n cc:dd") == {"aa": "bb", "cc": "dd"}
-    assert _plain_text_to_dict("aa:bb\n cc:dd\n") == {"aa": "bb", "cc": "dd"}
-    assert _plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {
-        "aa": "bb", "cc": "dd"}
-
-
def test_http_request_node_binary_file(monkeypatch):
    data = HttpRequestNodeData(
        title="test",

@@ -191,8 +183,7 @@ def test_http_request_node_form_with_file(monkeypatch):
    def attr_checker(*args, **kwargs):
        assert kwargs["data"] == {"name": "test"}
-        assert kwargs["files"] == {
-            "file": (None, b"test", "application/octet-stream")}
+        assert kwargs["files"] == {"file": (None, b"test", "application/octet-stream")}
        return httpx.Response(200, content=b"")

    monkeypatch.setattr(
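The first hunk above deletes test_plain_text_to_dict. Its assertions fully describe the helper it covered: split the input on newlines, skip blank lines, split each remaining line on the first colon, strip whitespace, and map a bare key to an empty string. Below is a minimal sketch of a function satisfying those removed assertions; it is not necessarily the project's actual _plain_text_to_dict.

def plain_text_to_dict(text: str) -> dict[str, str]:
    """Parse 'key: value' lines into a dict, per the deleted test's expectations."""
    result: dict[str, str] = {}
    for line in text.splitlines():
        if not line.strip():
            continue  # blank lines are ignored
        key, _, value = line.partition(":")  # a bare key yields an empty value
        result[key.strip()] = value.strip()
    return result


assert plain_text_to_dict("aa\n cc:") == {"aa": "", "cc": ""}
assert plain_text_to_dict("aa:bb\n\n cc : dd\n\n") == {"aa": "bb", "cc": "dd"}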
