Skip to content

Commit

Permalink
Merge branch 'main' into 1-implement-testing-and-benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
rbasu101 committed Mar 18, 2024
2 parents bda645d + f1a35f4 commit 072d342
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 52 deletions.
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ branch = true

[tool.ruff]
src = ["src"]
exclude = ["docs/source/conf.py", "src/dgipy/graph_app.py", "src/dgipy/network_graph.py","tests/test_dgidb.py","tests/test_graph_app.py"]
select = [
exclude = ["docs/source/conf.py","tests/test_dgidb.py","tests/test_graph_app.py"]
lint.select = [
"F", # https://docs.astral.sh/ruff/rules/#pyflakes-f
"E", "W", # https://docs.astral.sh/ruff/rules/#pycodestyle-e-w
"I", # https://docs.astral.sh/ruff/rules/#isort-i
Expand Down Expand Up @@ -98,7 +98,7 @@ select = [
"PGH", # https://docs.astral.sh/ruff/rules/#pygrep-hooks-pgh
"RUF", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
]
fixable = [
lint.fixable = [
"I",
"F401",
"D",
Expand Down Expand Up @@ -129,15 +129,15 @@ fixable = [
# E501 - line-too-long*
# W191 - tab-indentation*
# *ignored for compatibility with formatter
ignore = [
lint.ignore = [
"ANN101", "ANN003",
"D203", "D205", "D206", "D213", "D300", "D400", "D415",
"E111", "E114", "E117", "E501",
"W191",
"E722"
]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# ANN001 - missing-type-function-argument
# ANN102 - missing-type-cls
# ANN2 - missing-return-type
Expand Down
67 changes: 38 additions & 29 deletions src/dgipy/graph_app.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
"""Provides functionality to create a Dash web application for interacting with drug-gene data from DGIdb"""
import dash_bootstrap_components as dbc
from dash import Input, Output, State, ctx, dash, dcc, html

from dgipy import dgidb
from dgipy import network_graph as ng


def generate_app():
def generate_app() -> dash.Dash:
"""Initialize a Dash application object with a layout designed for visualizing: drug-gene interactions, options for user interactivity, and other visual elements.
:return: a python dash app that can be run with run_server()
"""
genes = dgidb.get_gene_list()
plot = ng.generate_plotly(None)
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

set_app_layout(app, plot, genes)
update_plot(app)
update_selected_node(app)
update_selected_node_display(app)
update_neighbor_dropdown(app)
update_edge_info(app)
__set_app_layout(app, plot, genes)
__update_plot(app)
__update_selected_node(app)
__update_selected_node_display(app)
__update_neighbor_dropdown(app)
__update_edge_info(app)

return app


def set_app_layout(app, plot, genes):
def __set_app_layout(app: dash.Dash, plot: ng.go.Figure, genes: list) -> None:
graph_display = dcc.Graph(
id="network-graph", figure=plot, style={"width": "100%", "height": "800px"}
)
Expand Down Expand Up @@ -86,12 +92,12 @@ def set_app_layout(app, plot, genes):
)


def update_plot(app):
def __update_plot(app: dash.Dash) -> None:
@app.callback(
[Output("graph", "data"), Output("network-graph", "figure")],
Input("gene-dropdown", "value"),
)
def update(selected_genes):
def update(selected_genes: None | list) -> tuple[dict | None, ng.go.Figure]:
if selected_genes is not None:
gene_interactions = dgidb.get_interactions(selected_genes)
updated_graph = ng.create_network(gene_interactions, selected_genes)
Expand All @@ -100,55 +106,58 @@ def update(selected_genes):
return None, ng.generate_plotly(None)


def update_selected_node(app):
def __update_selected_node(app: dash.Dash) -> None:
@app.callback(
Output("selected-node", "data"),
[Input("network-graph", "clickData"), Input("gene-dropdown", "value")],
[Input("network-graph", "click_data"), Input("gene-dropdown", "value")],
)
def update(clickData, newGene):
def update(click_data: None | dict, new_gene: None | list) -> str | dict:
if ctx.triggered_id == "gene-dropdown":
return ""
if clickData is not None and "points" in clickData:
selected_node = clickData["points"][0]
if click_data is not None and "points" in click_data:
selected_node = click_data["points"][0]
if "text" not in selected_node:
return dash.no_update
return selected_node
return dash.no_update


def update_selected_node_display(app):
def __update_selected_node_display(app: dash.Dash) -> None:
@app.callback(
Output("selected-node-text", "children"), Input("selected-node", "data")
)
def update(selected_node):
def update(selected_node: str | dict) -> str:
if selected_node != "":
return selected_node["text"]
return "No Node Selected"


def update_neighbor_dropdown(app):
def __update_neighbor_dropdown(app: dash.Dash) -> None:
@app.callback(
[Output("neighbor-dropdown", "options"), Output("neighbor-dropdown", "value")],
Input("selected-node", "data"),
)
def update(selected_node):
def update(selected_node: str | dict) -> tuple[list, None]:
if selected_node != "" and selected_node["curveNumber"] != 1:
return selected_node["customdata"], None
else:
return [], None
return [], None


def update_edge_info(app):
def __update_edge_info(app: dash.Dash) -> None:
@app.callback(
Output("edge-info-text", "children"),
[Input("selected-node", "data"), Input("neighbor-dropdown", "value")],
State("graph", "data"),
)
def update(selected_node, selected_neighbor, graph):
def update(
selected_node: str | dict, selected_neighbor: None | str, graph: None | dict
) -> str:
if selected_node == "":
return "No Edge Selected"
if selected_node["curveNumber"] == 1:
selected_data = get_node_data_from_id(graph["links"], selected_node["text"])
selected_data = __get_node_data_from_id(
graph["links"], selected_node["text"]
)
return (
"ID: "
+ str(selected_data["id"])
Expand All @@ -165,19 +174,19 @@ def update(selected_node, selected_neighbor, graph):
)
if selected_neighbor is not None:
edge_node_id = None
selected_node_is_gene = get_node_data_from_id(
selected_node_is_gene = __get_node_data_from_id(
graph["nodes"], selected_node["text"]
)["isGene"]
selected_neighbor_is_gene = get_node_data_from_id(
selected_neighbor_is_gene = __get_node_data_from_id(
graph["nodes"], selected_neighbor
)["isGene"]
if selected_node_is_gene == selected_neighbor_is_gene:
return dash.no_update
elif selected_node_is_gene:
if selected_node_is_gene:
edge_node_id = selected_node["text"] + " - " + selected_neighbor
elif selected_neighbor_is_gene:
edge_node_id = selected_neighbor + " - " + selected_node["text"]
selected_data = get_node_data_from_id(graph["links"], edge_node_id)
selected_data = __get_node_data_from_id(graph["links"], edge_node_id)
if selected_data is None:
return dash.no_update
return (
Expand All @@ -197,7 +206,7 @@ def update(selected_node, selected_neighbor, graph):
return "No Edge Selected"


def get_node_data_from_id(nodes, node_id):
def __get_node_data_from_id(nodes: list, node_id: str) -> dict | None:
for node in nodes:
if node["id"] == node_id:
return node
Expand Down
57 changes: 39 additions & 18 deletions src/dgipy/network_graph.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Provides functionality to create networkx graphs and pltoly figures for network visualization"""
import networkx as nx
import pandas as pd
import plotly.graph_objects as go

PLOTLY_SEED = 7


def initalize_network(interactions, selected_genes):
def __initalize_network(interactions: pd.DataFrame, selected_genes: list) -> nx.Graph:
interactions_graph = nx.Graph()
graphed_genes = set()
for index in interactions.index:
Expand All @@ -27,7 +29,7 @@ def initalize_network(interactions, selected_genes):
return interactions_graph


def add_node_attributes(interactions_graph):
def __add_node_attributes(interactions_graph: nx.Graph) -> None:
for node in interactions_graph.nodes:
is_gene = interactions_graph.nodes[node]["isGene"]
if is_gene:
Expand All @@ -45,26 +47,37 @@ def add_node_attributes(interactions_graph):
interactions_graph.nodes[node]["node_size"] = set_size


def create_network(interactions, selected_genes):
interactions_graph = initalize_network(interactions, selected_genes)
add_node_attributes(interactions_graph)
def create_network(interactions: pd.DataFrame, selected_genes: list) -> nx.Graph:
"""Create a networkx graph representing interactions between genes and drugs
:param interactions: DataFrame containing drug-gene interaction data
:param selected_genes: List containing genes used to query interaction data
:return: a networkx graph of drug-gene interactions
"""
interactions_graph = __initalize_network(interactions, selected_genes)
__add_node_attributes(interactions_graph)
return interactions_graph


def generate_plotly(graph):
def generate_plotly(graph: nx.Graph) -> go.Figure:
"""Create a plotly graph representing interactions between genes and drugs
:param graph: networkx graph to be formatted as a plotly graph
:return: a plotly graph of drug-gene interactions
"""
layout = go.Layout(
hovermode="closest",
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
xaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
yaxis={"showgrid": False, "zeroline": False, "showticklabels": False},
showlegend=True,
)
fig = go.Figure(layout=layout)

if graph is not None:
pos = nx.spring_layout(graph, seed=PLOTLY_SEED)

trace_nodes = create_trace_nodes(graph, pos)
trace_edges = create_trace_edges(graph, pos)
trace_nodes = __create_trace_nodes(graph, pos)
trace_edges = __create_trace_edges(graph, pos)

fig.add_trace(trace_edges[0])
fig.add_trace(trace_edges[1])
Expand All @@ -74,7 +87,7 @@ def generate_plotly(graph):
return fig


def create_trace_nodes(graph, pos):
def __create_trace_nodes(graph: nx.Graph, pos: dict) -> list:
nodes_by_group = {
"cyan": {
"node_x": [],
Expand Down Expand Up @@ -117,14 +130,17 @@ def create_trace_nodes(graph, pos):
nodes_by_group[node_color]["neighbors"].append(list(graph.neighbors(node)))

trace_nodes = []
for node_group, node in nodes_by_group.items():

for _, node in nodes_by_group.items():
trace_group = go.Scatter(
x=node["node_x"],
y=node["node_y"],
mode="markers",
marker=dict(
symbol="circle", size=node["node_size"], color=node["node_color"]
),
marker={
"symbol": "circle",
"size": node["node_size"],
"color": node["node_color"],
},
text=node["node_text"],
name=node["legend_name"],
customdata=node["neighbors"],
Expand All @@ -137,7 +153,7 @@ def create_trace_nodes(graph, pos):
return trace_nodes


def create_trace_edges(graph, pos):
def __create_trace_edges(graph: nx.Graph, pos: dict) -> go.Scatter:
edge_x = []
edge_y = []

Expand All @@ -163,7 +179,7 @@ def create_trace_edges(graph, pos):
x=edge_x,
y=edge_y,
mode="lines",
line=dict(width=0.5, color="gray"),
line={"width": 0.5, "color": "gray"},
hoverinfo="none",
showlegend=False,
)
Expand All @@ -181,5 +197,10 @@ def create_trace_edges(graph, pos):
return trace_edges, i_trace_edges


def generate_json(graph):
def generate_json(graph: nx.Graph) -> dict:
"""Generate a JSON representation of a networkx graph
:param graph: networkx graph to be formatted as a JSON
:return: a dictionary representing the JSON data of the graph
"""
return nx.node_link_data(graph)

0 comments on commit 072d342

Please sign in to comment.