-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdisplay_graph.py
29 lines (23 loc) · 959 Bytes
/
display_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import subprocess
import sys
from langchain_core.runnables.graph import MermaidDrawMethod, CurveStyle
import random
def display_graph(graph, output_folder="output", file_name="graph"):
    """Render *graph* as a Mermaid PNG, save it to disk and open it for viewing.

    Used to visualise the graph in every lesson.

    Args:
        graph: Object exposing ``get_graph(xray=...)`` whose result provides
            ``draw_mermaid_png(...)`` (presumably a compiled LangGraph graph —
            confirm against callers).
        output_folder: Directory the PNG is written to; created if missing.
        file_name: Prefix of the PNG file name. A random suffix is appended so
            repeated calls do not overwrite earlier images.

    Returns:
        The path of the written PNG file.
    """
    import uuid  # local import: only needed for a collision-resistant file suffix

    # xray=1 expands subgraphs one level; rendering is done with the Mermaid
    # "API" draw method (presumably a remote renderer — verify if offline use matters).
    mermaid_png = graph.get_graph(xray=1).draw_mermaid_png(
        draw_method=MermaidDrawMethod.API,
        curve_style=CurveStyle.NATURAL,
    )

    # Use the caller-supplied folder rather than a hard-coded one.
    os.makedirs(output_folder, exist_ok=True)
    filename = os.path.join(output_folder, f"{file_name}_{uuid.uuid4().hex[:8]}.png")
    with open(filename, "wb") as f:
        f.write(mermaid_png)

    _open_with_default_viewer(filename)
    return filename


def _open_with_default_viewer(path):
    """Best-effort: open *path* with the platform's default image viewer.

    Failure to launch a viewer (e.g. no ``xdg-open`` on a headless Linux box)
    is reported on stderr instead of raising, since the image is already saved.
    """
    try:
        if sys.platform.startswith("darwin"):
            subprocess.call(("open", path))
        elif sys.platform.startswith("linux"):
            subprocess.call(("xdg-open", path))
        elif sys.platform.startswith("win"):
            os.startfile(path)  # only exists on Windows, hence the platform guard
        else:
            print(f"Graph saved to {path} (no viewer configured for {sys.platform}).",
                  file=sys.stderr)
    except OSError as err:
        print(f"Graph saved to {path}, but it could not be opened automatically: {err}",
              file=sys.stderr)