Skip to content

Commit

Permalink
Several changes to generalize loading and plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
Anastasia Baryshnikova committed Nov 26, 2019
1 parent 224adfc commit 6dade73
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 27 deletions.
58 changes: 40 additions & 18 deletions safe.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#! /usr/bin/env python

import configparser
import os
import sys
Expand Down Expand Up @@ -49,7 +51,7 @@ def __init__(self,
self.path_to_attribute_file = None

self.graph = None
self.node_key_attribute = 'key'
self.node_key_attribute = 'label_orf'

self.attributes = None
self.nodes = None
Expand Down Expand Up @@ -201,9 +203,10 @@ def load_network(self, **kwargs):
self.graph = load_network_from_mat(self.path_to_network_file, verbose=self.verbose)
elif file_extension == '.gpickle':
self.graph = load_network_from_gpickle(self.path_to_network_file, verbose=self.verbose)
self.node_key_attribute = 'label_orf'
elif file_extension == '.txt':
self.graph = load_network_from_txt(self.path_to_network_file, verbose=self.verbose)
self.graph = load_network_from_txt(self.path_to_network_file,
node_key_attribute=self.node_key_attribute,
verbose=self.verbose)
elif file_extension == '.cys':
self.graph = load_network_from_cys(self.path_to_network_file, verbose=self.verbose)

Expand All @@ -217,6 +220,14 @@ def load_network(self, **kwargs):
'key': list(key_list.values()),
'label': list(label_list.values())})

def save_network(self, **kwargs):
if 'output_file' in kwargs:
output_file = kwargs['output_file']
else:
output_file = os.path.join(os.getcwd(), self.path_to_network_file + '.gpickle')

nx.write_gpickle(self.graph, output_file)

def load_attributes(self, **kwargs):

# Overwrite the global settings, if required
Expand Down Expand Up @@ -599,11 +610,16 @@ def trim_domains(self, **kwargs):
print('Removed %d domains because they were the top choice for less than %d neighborhoods.'
% (len(to_remove), self.attribute_enrichment_min_size))

def plot_network(self):
def plot_network(self, background_color='#000000'):
plot_network(self.graph, background_color=background_color)

plot_network(self.graph)
def plot_composite_network(self, show_each_domain=False, show_domain_ids=True,
save_fig=None,
background_color='#000000'):

def plot_composite_network(self, show_each_domain=False, show_domain_ids=True):
foreground_color = '#ffffff'
if background_color == '#ffffff':
foreground_color = '#000000'

domains = np.sort(self.attributes['domain'].unique())
# domains = self.domains.index.values
Expand Down Expand Up @@ -651,28 +667,28 @@ def plot_composite_network(self, show_each_domain=False, show_domain_ids=True):
figsize = (10 * ncols, 10 * nrows)

[fig, axes] = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex=True, sharey=True,
facecolor='#000000')
facecolor=background_color)
axes = axes.ravel()

# First, plot the network
ax = axes[0]
ax = plot_network(self.graph, ax=ax)
ax = plot_network(self.graph, ax=ax, background_color=background_color)

# Then, plot the composite network
axes[1].scatter(node_xy[ix, 0], node_xy[ix, 1], c=c[ix], s=60, edgecolor=None)
axes[1].set_aspect('equal')
axes[1].set_facecolor('#000000')
axes[1].set_facecolor(background_color)

# Plot a circle around the network
plot_network_contour(self.graph, axes[1])
plot_network_contour(self.graph, axes[1], background_color=background_color)

if show_domain_ids:
for domain in domains[domains > 0]:
idx = self.node2domain['primary_domain'] == domain
centroid_x = np.nanmean(node_xy[idx, 0])
centroid_y = np.nanmean(node_xy[idx, 1])
axes[1].text(centroid_x, centroid_y, str(domain),
fontdict={'size': 16, 'color': 'white', 'weight': 'bold'})
fontdict={'size': 16, 'color': foreground_color, 'weight': 'bold'})

# Then, plot each domain separately, if requested
if show_each_domain:
Expand All @@ -692,12 +708,17 @@ def plot_composite_network(self, show_each_domain=False, show_domain_ids=True):
axes[1+domain].scatter(node_xy[idx, 0], node_xy[idx, 1], c=c[idx],
s=60, edgecolor=None)
axes[1+domain].set_aspect('equal')
axes[1+domain].set_facecolor('#000000')
axes[1+domain].set_facecolor(background_color)
axes[1+domain].set_title('Domain %d\n%s' % (domain, self.domains.loc[domain, 'label']),
color='#ffffff')
plot_network_contour(self.graph, axes[1+domain])
color=foreground_color)
plot_network_contour(self.graph, axes[1+domain], background_color=background_color)

fig.set_facecolor("#000000")
fig.set_facecolor(background_color)

if save_fig:
path_to_fig = save_fig
print('Output path: %s' % path_to_fig)
plt.savefig(path_to_fig, facecolor=background_color)

def plot_sample_attributes(self, attributes=1, top_attributes_only=False,
show_network=True,
Expand All @@ -711,7 +732,7 @@ def plot_sample_attributes(self, attributes=1, top_attributes_only=False,
foreground_color = '#ffffff'
if background_color == '#ffffff':
foreground_color = '#000000'

all_attributes = self.attributes.index.values
if top_attributes_only:
all_attributes = all_attributes[self.attributes['top']]
Expand All @@ -738,7 +759,8 @@ def plot_sample_attributes(self, attributes=1, top_attributes_only=False,
ncols = np.min([len(attributes)+nax, 2])
figsize = (10*ncols, 10*nrows)

[fig, axes] = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex=True, sharey=True)
[fig, axes] = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex=True, sharey=True,
facecolor=background_color)

if isinstance(axes, np.ndarray):
axes = axes.ravel()
Expand All @@ -748,7 +770,7 @@ def plot_sample_attributes(self, attributes=1, top_attributes_only=False,
# First, plot the network (if required)
if show_network:
ax = axes[0]
_ = plot_network(self.graph, ax=ax)
ax = plot_network(self.graph, ax=ax, background_color=background_color)

score = self.nes

Expand Down
22 changes: 13 additions & 9 deletions safe_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from xml.dom import minidom


def load_network_from_txt(filename, layout='spring_embedded', verbose=True):
def load_network_from_txt(filename, layout='spring_embedded', node_key_attribute='key', verbose=True):

filename = re.sub('~', expanduser('~'), filename)
data = pd.read_table(filename, sep='\t', header=None)
Expand Down Expand Up @@ -62,7 +62,7 @@ def load_network_from_txt(filename, layout='spring_embedded', verbose=True):

for n in G:
G.nodes[n]['label'] = nodes.loc[n, 'node_label1']
G.nodes[n]['key'] = nodes.loc[n, 'node_key1']
G.nodes[n][node_key_attribute] = nodes.loc[n, 'node_key1']

# Add the edges between the nodes
edges = [tuple(x) for x in data[['node_index1', 'node_index2']].values]
Expand Down Expand Up @@ -346,35 +346,39 @@ def load_attributes(attribute_file='', node_label_order=None, mask_duplicates=Fa
return attributes, node_label_order, node2attribute


def plot_network(G, ax=None):
def plot_network(G, ax=None, background_color='#000000'):

foreground_color = '#ffffff'
if background_color == '#ffffff':
foreground_color = '#000000'

node_xy = get_node_coordinates(G)

if ax is None:
fig, ax = plt.subplots(figsize=(20, 10), facecolor='black', edgecolor='white')
fig.set_facecolor("#000000")
fig, ax = plt.subplots(figsize=(20, 10), facecolor=background_color, edgecolor=foreground_color)
fig.set_facecolor(background_color)

# Randomly sample a fraction of the edges (when network is too big)
edges = tuple(G.edges())
if len(edges) > 30000:
edges = random.sample(edges, int(len(edges)*0.1))

nx.draw(G, ax=ax, pos=node_xy, edgelist=edges,
node_color='#ffffff', edge_color='#ffffff', node_size=10, width=1, alpha=0.2)
node_color=foreground_color, edge_color=foreground_color, node_size=10, width=1, alpha=0.2)

ax.set_aspect('equal')
ax.set_facecolor('#000000')
ax.set_facecolor(background_color)

ax.grid(False)
ax.invert_yaxis()
ax.margins(0.1, 0.1)

ax.set_title('Network', color='#ffffff')
ax.set_title('Network', color=foreground_color)

plt.axis('off')

try:
fig.set_facecolor("#000000")
fig.set_facecolor(background_color)
except NameError:
pass

Expand Down

0 comments on commit 6dade73

Please sign in to comment.