Skip to content

Commit

Permalink
refine loginfo about graprag progress (#1823)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?



### Type of change

- [x] Refactoring
  • Loading branch information
KevinHuSh authored Aug 6, 2024
1 parent 3fd7db4 commit 43199c4
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 15 deletions.
3 changes: 2 additions & 1 deletion api/db/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ def update_progress(cls):
if 0 <= t.progress < 1:
finished = False
prg += t.progress if t.progress >= 0 else 0
msg.append(t.progress_msg)
if t.progress_msg not in msg:
msg.append(t.progress_msg)
if t.progress == -1:
bad += 1
prg /= len(tsks)
Expand Down
16 changes: 11 additions & 5 deletions graphrag/community_reports_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
import re
import traceback
from dataclasses import dataclass
from typing import Any, List

from typing import Any, List, Callable
import networkx as nx
import pandas as pd

from graphrag import leiden
from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -67,11 +67,14 @@ def __init__(
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_report_length = max_report_length or 1500

def __call__(self, graph: nx.Graph):
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
communities: dict[str, dict[str, List]] = leiden.run(graph, {})
total = sum([len(comm.items()) for _, comm in communities.items()])
relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
res_str = []
res_dict = []
over, token_count = 0, 0
st = timer()
for level, comm in communities.items():
for cm_id, ents in comm.items():
weight = ents["weight"]
Expand All @@ -84,9 +87,10 @@ def __call__(self, graph: nx.Graph):
"relation_df": rela_df.to_csv(index_label="id")
}
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.5}
gen_conf = {"temperature": 0.3}
try:
response = self._llm.chat(text, [], gen_conf)
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)
print(response)
Expand All @@ -108,6 +112,8 @@ def __call__(self, graph: nx.Graph):
add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))
res_dict.append(response)
over += 1
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

return CommunityReportsResult(
structured_output=res_dict,
Expand Down
20 changes: 15 additions & 5 deletions graphrag/graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
import re
import traceback
from dataclasses import dataclass
from typing import Any, Mapping
from typing import Any, Mapping, Callable
import tiktoken
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
import networkx as nx
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
Expand Down Expand Up @@ -103,7 +104,9 @@ def __init__(
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

def __call__(
self, texts: list[str], prompt_variables: dict[str, Any] | None = None
self, texts: list[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None
) -> GraphExtractionResult:
"""Call method definition."""
if prompt_variables is None:
Expand All @@ -127,12 +130,17 @@ def __call__(
),
}

st = timer()
total = len(texts)
total_token_count = 0
for doc_index, text in enumerate(texts):
try:
# Invoke the entity extraction
result = self._process_document(text, prompt_variables)
result, token_count = self._process_document(text, prompt_variables)
source_doc_map[doc_index] = text
all_records[doc_index] = result
total_token_count += token_count
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e:
logging.exception("error extracting graph")
self._on_error(
Expand Down Expand Up @@ -162,9 +170,11 @@ def _process_document(
**prompt_variables,
self._input_text_key: text,
}
token_count = 0
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.5}
gen_conf = {"temperature": 0.3}
response = self._llm.chat(text, [], gen_conf)
token_count = num_tokens_from_string(text + response)

results = response or ""
history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
Expand All @@ -185,7 +195,7 @@ def _process_document(
if continuation != "YES":
break

return results
return results, token_count

def _process_results(
self,
Expand Down
6 changes: 3 additions & 3 deletions graphrag/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
for i in range(len(chunks)):
tkn_cnt = num_tokens_from_string(chunks[i])
if cnt+tkn_cnt >= left_token_count and texts:
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
texts = []
cnt = 0
texts.append(chunks[i])
Expand All @@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
graphs = []
for i, _ in enumerate(threads):
graphs.append(_.result().output)
callback(0.5 + 0.1*i/len(threads))
callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")

graph = reduce(graph_merge, graphs)
er = EntityResolution(llm_bdl)
Expand All @@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent

callback(0.6, "Extracting community reports.")
cr = CommunityReportsExtractor(llm_bdl)
cr = cr(graph)
cr = cr(graph, callback=callback)
for community, desc in zip(cr.structured_output, cr.output):
chunk = {
"title_tks": rag_tokenizer.tokenize(community["title"]),
Expand Down
2 changes: 1 addition & 1 deletion rag/nlp/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def search(self, req, idxnm, emb_mdl=None):
es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry, _ = self.qryr.question(qst, min_match="10%")
bqry = self._add_filters(bqry)
bqry = self._add_filters(bqry, req)
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.17
Expand Down

0 comments on commit 43199c4

Please sign in to comment.