# -*- coding: utf-8 -*-
r"""
The :mod:`pygsp.plotting` module implements functionality to plot PyGSP objects
with a `pyqtgraph <http://www.pyqtgraph.org>`_ or `matplotlib
<https://matplotlib.org>`_ drawing backend (which can be controlled by the
:data:`BACKEND` constant or individually for each plotting call):

* graphs from :mod:`pygsp.graphs` with :func:`plot_graph`,
  :func:`plot_spectrogram`, and :func:`plot_signal`,
* filters from :mod:`pygsp.filters` with :func:`plot_filter`.

.. data:: BACKEND

    Indicates which drawing backend to use if none is provided to the
    plotting functions. Should be either 'matplotlib' or 'pyqtgraph'. In
    general, pyqtgraph is better for interactive exploration, while
    matplotlib is better at generating figures to be included in papers or
    elsewhere.
"""
from __future__ import division
import traceback
import numpy as np
from pygsp import utils
_logger = utils.build_logger(__name__)

# Both drawing backends are optional: a failed import is logged and recorded
# in a flag (_PLT_IMPORT / _QTG_IMPORT) that the plotting functions check
# before using that backend.
try:
    import matplotlib.pyplot as plt
    # Not used directly, but needed for 3D projection.
    from mpl_toolkits.mplot3d import Axes3D  # noqa
    _PLT_IMPORT = True
except Exception:
    _logger.error('Cannot import packages for matplotlib: {}'.format(
        traceback.format_exc()))
    _PLT_IMPORT = False

try:
    import pyqtgraph as qtg
    import pyqtgraph.opengl as gl
    from pyqtgraph.Qt import QtGui
    _QTG_IMPORT = True
except Exception:
    _logger.error('Cannot import packages for pyqtgraph: {}'.format(
        traceback.format_exc()))
    _QTG_IMPORT = False

# Drawing backend used when a plotting function receives backend=None.
BACKEND = 'pyqtgraph'

# References to every pyqtgraph window/widget and matplotlib figure created by
# this module. Holding them keeps the windows from being garbage collected
# (which would close them) and lets close_all() close them all at once.
_qtg_windows = []
_qtg_widgets = []
_plt_figures = []
def _plt_handle_figure(plot):
    r"""Decorate a matplotlib plotting function with figure management.

    The decorated function must accept an ``ax`` keyword argument. The
    returned wrapper:

    * creates a new figure and axis when ``ax`` is not given or is None
      (a 3D axis if ``obj.coords`` has three columns), and registers the
      figure so that :func:`close_all` can close it;
    * pops ``plot_name`` (used as the axis title, default '') and
      ``save_as`` from the keyword arguments before calling ``plot``;
    * saves the figure as ``save_as``.png and ``save_as``.pdf when
      ``save_as`` is given (also when the axis was passed in), otherwise
      shows the figure it created. A figure owning a caller-supplied axis is
      never shown.
    """
    from functools import wraps

    @wraps(plot)  # Keep the name and docstring of the decorated function.
    def inner(obj, *args, **kwargs):
        fig = None  # Only set when this wrapper creates the figure.
        if kwargs.get('ax') is None:
            fig = plt.figure()
            _plt_figures.append(fig)
            if (hasattr(obj, 'coords') and obj.coords.ndim == 2 and
                    obj.coords.shape[1] == 3):
                ax = fig.add_subplot(111, projection='3d')
            else:
                ax = fig.add_subplot(111)
            kwargs['ax'] = ax

        save_as = kwargs.pop('save_as', None)
        plot_name = kwargs.pop('plot_name', '')

        plot(obj, *args, **kwargs)

        ax = kwargs['ax']
        ax.set_title(plot_name)

        if save_as is not None:
            # Save the figure holding the axis, even if it was passed in.
            target = fig if fig is not None else ax.figure
            target.savefig(save_as + '.png')
            target.savefig(save_as + '.pdf')
        elif fig is not None:
            fig.show(warn=False)

    return inner
def close_all():
    r"""
    Close every window, widget and figure opened by this module.

    pyqtgraph windows and widgets are closed, matplotlib figures are closed
    through ``plt.close``, and each registry is then replaced by a new empty
    list.
    """
    # Releasing the last reference would let the windows be garbage collected
    # (and thus closed); close() is called explicitly all the same.
    global _qtg_windows, _qtg_widgets, _plt_figures

    for qtg_window in _qtg_windows:
        qtg_window.close()
    _qtg_windows = []

    for qtg_widget in _qtg_widgets:
        qtg_widget.close()
    _qtg_widgets = []

    for mpl_figure in _plt_figures:
        plt.close(mpl_figure)
    _plt_figures = []
def show(*args, **kwargs):
    r"""
    Show created figures.

    Thin alias to :func:`matplotlib.pyplot.show`: every positional and
    keyword argument is forwarded unchanged. Whether the call blocks is
    decided by matplotlib (its ``block`` argument and interactive mode).
    """
    display_figures = plt.show
    display_figures(*args, **kwargs)
def close(*args, **kwargs):
    r"""
    Close created figures.

    Thin alias to :func:`matplotlib.pyplot.close`: every positional and
    keyword argument is forwarded unchanged.
    """
    close_figures = plt.close
    close_figures(*args, **kwargs)
def plot(O, **kwargs):
    r"""
    Main plotting function.

    This convenience function calls the ``plot`` method of the passed object
    (e.g. a Graph or a Filter), which in turn calls :func:`plot_graph` or
    :func:`plot_filter`. Keyword arguments are forwarded to that method.

    Parameters
    ----------
    O : Graph, Filter
        object to plot

    Raises
    ------
    TypeError
        If ``O`` has no ``plot`` method.

    Examples
    --------
    >>> from pygsp import plotting
    >>> G = graphs.Logo()
    >>> plotting.plot(G)
    """
    # Only the attribute lookup is guarded: an AttributeError raised while
    # plotting (e.g. a graph without coordinates) must reach the caller
    # unchanged instead of being reported as an unrecognized object.
    try:
        plot_method = O.plot
    except AttributeError:
        raise TypeError('Unrecognized object, i.e. not a Graph or Filter.')
    plot_method(**kwargs)
def plot_graph(G, backend=None, **kwargs):
    r"""
    Plot a graph.

    Drawing a list of graphs is not supported yet.

    Parameters
    ----------
    G : Graph
        Graph to plot. Must have 2D or 3D coordinates in ``G.coords``.
    show_edges : bool
        True to draw edges, false to only draw vertices.
        Default True if less than 10,000 edges to draw.
        Note that drawing a large number of edges might be particularly slow.
    backend: {'matplotlib', 'pyqtgraph'}
        Defines the drawing backend to use. Defaults to :data:`BACKEND`.
    vertex_size : float
        Size of circle representing each node.
        Defaults to ``G.plotting['vertex_size']``.
    plot_name : str
        name of the plot
    save_as : str
        Whether to save the plot as save_as.png and save_as.pdf. Shown in a
        window if None (default). Only available with the matplotlib backend.
    ax : matplotlib.axes
        Axes where to draw the graph. Optional, created if not passed. Only
        available with the matplotlib backend.

    Raises
    ------
    AttributeError
        If G has no coordinates, or they are not 2D or 3D.
    ValueError
        If the requested backend is unknown or could not be imported.

    Examples
    --------
    >>> from pygsp import plotting
    >>> G = graphs.Logo()
    >>> plotting.plot_graph(G)
    """
    if not hasattr(G, 'coords'):
        raise AttributeError('Graph has no coordinate set. '
                             'Please run G.set_coordinates() first.')
    has_valid_coords = G.coords.ndim == 2 and G.coords.shape[1] in [2, 3]
    if not has_valid_coords:
        raise AttributeError('Coordinates should be in 2D or 3D space.')

    # Fill in defaults for the options the caller did not give.
    kwargs.setdefault('show_edges', G.Ne < 10e3)
    kwargs.setdefault('vertex_size', G.plotting['vertex_size'])
    kwargs.setdefault('plot_name', u'{}\nG.N={} nodes, G.Ne={} edges'.format(
        G.gtype, G.N, G.Ne))

    chosen = BACKEND if backend is None else backend

    # Directed graphs are drawn through their symmetrized version.
    G = _handle_directed(G)

    if chosen == 'pyqtgraph' and _QTG_IMPORT:
        _qtg_plot_graph(G, **kwargs)
    elif chosen == 'matplotlib' and _PLT_IMPORT:
        _plt_plot_graph(G, **kwargs)
    else:
        raise ValueError('The {} backend is not available.'.format(chosen))
@_plt_handle_figure
def _plt_plot_graph(G, show_edges, vertex_size, ax):
    # Matplotlib drawing of G on ``ax`` (the decorator creates a 3D axis when
    # G.coords has three columns). With ``show_edges``, edges are drawn as
    # line segments whose endpoints carry the vertex markers; otherwise only
    # the vertices are scattered.
    # TODO handling when G is a list of graphs
    # TODO integrate param when G is a clustered graph

    def edge_line_style():
        # Keyword arguments shared by every ax.plot call that draws edges.
        return dict(linewidth=G.plotting['edge_width'],
                    color=G.plotting['edge_color'],
                    linestyle=G.plotting['edge_style'],
                    marker='o', markersize=vertex_size/10,
                    markerfacecolor=G.plotting['vertex_color'],
                    markeredgecolor=G.plotting['vertex_color'])

    if not show_edges:
        # TODO: is ax.plot(G.coords[:, 0], G.coords[:, 1], 'bo') faster?
        if G.coords.shape[1] in (2, 3):
            columns = [G.coords[:, dim] for dim in range(G.coords.shape[1])]
            ax.scatter(*columns, marker='o', s=vertex_size,
                       c=G.plotting['vertex_color'])
    elif G.is_directed():
        raise NotImplementedError
    elif G.coords.shape[1] == 2:
        xs, ys = _get_coords(G)
        ax.plot(xs, ys, **edge_line_style())
    elif G.coords.shape[1] == 3:
        # TODO: very dirty. Cannot we prepare a set of lines?
        xs, ys, zs = _get_coords(G)
        for start in range(0, xs.size, 2):
            segment = slice(start, start + 2)
            ax.plot(xs[segment], ys[segment], zs[segment],
                    **edge_line_style())

    if G.coords.shape[1] == 3:
        # The camera settings are optional entries of G.plotting.
        try:
            ax.view_init(elev=G.plotting['elevation'],
                         azim=G.plotting['azimuth'])
            ax.dist = G.plotting['distance']
        except KeyError:
            pass
def _qtg_plot_graph(G, show_edges, vertex_size, plot_name):
    # pyqtgraph drawing of G: a GraphicsWindow for 2D coordinates, an OpenGL
    # GLViewWidget for 3D coordinates. The created window/widget is kept in
    # a module registry so that close_all() can close it.
    # TODO handling when G is a list of graphs
    global _qtg_windows, _qtg_widgets

    if G.is_directed():
        raise NotImplementedError

    if G.coords.shape[1] == 2:
        window = qtg.GraphicsWindow()
        window.setWindowTitle(plot_name)
        view = window.addViewBox()
        view.setAspectLocked()
        # A None pen hides the edges; colors are scaled from [0, 1] to 255.
        pen = (tuple(np.array(G.plotting['edge_color']) * 255)
               if show_edges else None)
        edges = _get_coords(G, edge_list=True)
        graph_item = qtg.GraphItem(pos=G.coords, adj=edges, pen=pen,
                                   size=vertex_size/10)
        view.addItem(graph_item)
        _qtg_windows.append(window)

    elif G.coords.shape[1] == 3:
        if not QtGui.QApplication.instance():
            QtGui.QApplication([])  # We want only one application.
        widget = gl.GLViewWidget()
        widget.opts['distance'] = 10
        widget.show()
        widget.setWindowTitle(plot_name)
        if show_edges:
            xs, ys, zs = _get_coords(G)
            segments = np.stack((xs, ys, zs), axis=1)
            widget.addItem(gl.GLLinePlotItem(pos=segments, mode='lines',
                                             color=G.plotting['edge_color']))
        widget.addItem(gl.GLScatterPlotItem(pos=G.coords,
                                            size=vertex_size/3,
                                            color=G.plotting['vertex_color']))
        _qtg_widgets.append(widget)
@_plt_handle_figure
def plot_filter(filters, npoints=1000, line_width=4, x_width=3,
                x_size=10, plot_eigenvalues=None, show_sum=None, ax=None):
    r"""
    Plot the spectral response of a filter bank, a set of graph filters.

    Parameters
    ----------
    filters : Filter
        Filter bank to plot.
    npoints : int
        Number of points where the filters are evaluated, evenly spaced on
        [0, G.lmax].
    line_width : int
        Width of the filters plots.
    x_width : int
        Currently unused (eigenvalues are drawn as vertical lines).
    x_size : int
        Currently unused (eigenvalues are drawn as vertical lines).
    plot_eigenvalues : boolean
        To draw a light gray vertical line at each eigenvalue of the graph.
        You need to compute the Fourier basis to use this option. By default
        the eigenvalues are plotted if the graph has already computed them.
    show_sum : boolean
        To plot an extra line showing the sum of the squared magnitudes
        of the filters (default True if there are multiple filters).
    plot_name : string
        name of the plot
    save_as : str
        Whether to save the plot as save_as.png and save_as.pdf. Shown in a
        window if None (default). Only available with the matplotlib backend.
    ax : matplotlib.axes
        Axes where to draw the graph. Optional, created if not passed. Only
        available with the matplotlib backend.

    Examples
    --------
    >>> from pygsp import plotting
    >>> G = graphs.Logo()
    >>> mh = filters.MexicanHat(G)
    >>> plotting.plot_filter(mh)
    """
    G = filters.G

    if plot_eigenvalues is None:
        # Do not trigger an eigendecomposition just for plotting.
        plot_eigenvalues = hasattr(G, '_e')
    if show_sum is None:
        show_sum = filters.Nf > 1

    if plot_eigenvalues:
        for e in G.e:
            ax.axvline(x=e, color=[0.9]*3, linewidth=1)

    x = np.linspace(0, G.lmax, npoints)
    y = filters.evaluate(x).T
    if y.ndim == 1:
        # A single response as a 1-D array: make it one column so that the
        # sum over axis 1 below also works for a single filter.
        y = y[:, np.newaxis]
    ax.plot(x, y, linewidth=line_width)

    # TODO: plot highlighted eigenvalues

    if show_sum:
        ax.plot(x, np.sum(y**2, 1), 'k', linewidth=line_width)

    # Raw strings: '\l' and '\h' are invalid escape sequences otherwise.
    ax.set_xlabel(r"$\lambda$: laplacian's eigenvalues / graph frequencies")
    ax.set_ylabel(r'$\hat{g}(\lambda)$: filter response')
def _signal_limits(signal, margin=0.05):
    # Default color/axis limits: the signal range widened on both sides by
    # ``margin`` times the magnitude of each bound, so that the extreme
    # values always fall inside the limits whatever their sign.
    vmin, vmax = signal.min(), signal.max()
    return [vmin - margin * np.abs(vmin), vmax + margin * np.abs(vmax)]


def plot_signal(G, signal, backend=None, **kwargs):
    r"""
    Plot a signal on top of a graph.

    Parameters
    ----------
    G : Graph
        Graph to plot a signal on top.
    signal : array_like
        Signal to plot. Signal length should be equal to the number of nodes.
        Several signals (one per column) are only supported with 1D
        coordinates.
    show_edges : bool
        True to draw edges, false to only draw vertices.
        Default True if less than 10,000 edges to draw.
        Note that drawing a large number of edges might be particularly slow.
    cp : list of int
        NOT IMPLEMENTED. Camera position when plotting a 3D graph.
    vertex_size : float
        Size of circle representing each node.
    highlight : iterable
        List of indices of vertices to be highlighted.
        Useful to e.g. show where a filter was localized.
        Only available with the matplotlib backend.
    colorbar : bool
        Whether to plot a colorbar indicating the signal's amplitude.
        Only available with the matplotlib backend.
    limits : [vmin, vmax]
        Maps colors from vmin to vmax. Defaults to the signal minimum and
        maximum, each widened by 5% of its magnitude.
    bar : boolean
        NOT IMPLEMENTED. Signal values are displayed using colors when False,
        and bars when True (default False).
    bar_width : int
        NOT IMPLEMENTED. Width of the bar (default 1).
    backend: {'matplotlib', 'pyqtgraph'}
        Defines the drawing backend to use. Defaults to :data:`BACKEND`.
    plot_name : string
        Name of the plot.
    save_as : str
        Whether to save the plot as save_as.png and save_as.pdf. Shown in a
        window if None (default). Only available with the matplotlib backend.
    ax : matplotlib.axes
        Axes where to draw the graph. Optional, created if not passed. Only
        available with the matplotlib backend.

    Raises
    ------
    AttributeError
        If G has no coordinates, or they are not 1D, 2D or 3D.
    ValueError
        If the signal has the wrong length, is complex, has several columns
        with 2D/3D coordinates, or the backend is not available.

    Examples
    --------
    >>> from pygsp import plotting
    >>> G = graphs.Grid2d(4)
    >>> signal = np.sin((np.arange(16) * 2*np.pi/16))
    >>> plotting.plot_signal(G, signal)
    """
    if not hasattr(G, 'coords'):
        raise AttributeError('Graph has no coordinate set. '
                             'Please run G.set_coordinates() first.')
    check_2d_3d = (G.coords.ndim != 2) or (G.coords.shape[1] not in [2, 3])
    if G.coords.ndim != 1 and check_2d_3d:
        raise AttributeError('Coordinates should be in 1D, 2D or 3D space.')

    # Accept any array_like; keep at least one dimension for single nodes.
    signal = np.atleast_1d(np.asarray(signal).squeeze())
    if G.coords.ndim == 2 and signal.ndim != 1:
        raise ValueError('Can plot only one signal (not {}) with {}D '
                         'coordinates.'.format(signal.shape[1],
                                               G.coords.shape[1]))
    if signal.shape[0] != G.N:
        raise ValueError('Signal length is {}, should be '
                         'G.N = {}.'.format(signal.shape[0], G.N))
    if np.sum(np.abs(signal.imag)) > 1e-10:
        raise ValueError("Can't display complex signal.")
    # Drop a (negligible) imaginary part so that backends get real values.
    signal = signal.real

    kwargs['show_edges'] = kwargs.pop('show_edges', G.Ne < 10e3)

    default = G.plotting['vertex_size']
    kwargs['vertex_size'] = kwargs.pop('vertex_size', default)

    plot_name = u'{}\nG.N={} nodes, G.Ne={} edges'.format(G.gtype, G.N, G.Ne)
    kwargs['plot_name'] = kwargs.pop('plot_name', plot_name)

    kwargs['limits'] = kwargs.pop('limits', _signal_limits(signal))

    if backend is None:
        backend = BACKEND

    # Directed graphs are drawn through their symmetrized version.
    G = _handle_directed(G)

    if backend == 'pyqtgraph' and _QTG_IMPORT:
        _qtg_plot_signal(G, signal, **kwargs)
    elif backend == 'matplotlib' and _PLT_IMPORT:
        _plt_plot_signal(G, signal, **kwargs)
    else:
        raise ValueError('The {} backend is not available.'.format(backend))
@_plt_handle_figure
def _plt_plot_signal(G, signal, show_edges, limits, ax,
                     vertex_size, highlight=None, colorbar=True):
    # Matplotlib drawing of ``signal`` on G on ``ax``: a curve over the
    # coordinates for 1D coords, vertex colors (range ``limits``) for 2D/3D
    # coords. ``highlight`` is a vertex index or an iterable of indices
    # (default: none), marked in color 'C1'.
    if highlight is None:
        # None replaces the former shared mutable default ``[]``.
        highlight = []

    if show_edges:
        if G.is_directed():
            raise NotImplementedError
        else:
            if G.coords.ndim == 1:
                pass
            elif G.coords.shape[1] == 2:
                x, y = _get_coords(G)
                ax.plot(x, y, linewidth=G.plotting['edge_width'],
                        color=G.plotting['edge_color'],
                        linestyle=G.plotting['edge_style'],
                        zorder=1)
            elif G.coords.shape[1] == 3:
                # TODO: very dirty. Cannot we prepare a set of lines?
                x, y, z = _get_coords(G)
                for i in range(0, x.size, 2):
                    x2, y2, z2 = x[i:i+2], y[i:i+2], z[i:i+2]
                    ax.plot(x2, y2, z2, linewidth=G.plotting['edge_width'],
                            color=G.plotting['edge_color'],
                            linestyle=G.plotting['edge_style'],
                            zorder=1)

    # A single index is wrapped so that coords_hl keeps the vertex axis.
    try:
        iter(highlight)
    except TypeError:
        highlight = [highlight]
    coords_hl = G.coords[highlight]

    if G.coords.ndim == 1:
        ax.plot(G.coords, signal)
        ax.set_ylim(limits)
        for coord_hl in coords_hl:
            ax.axvline(x=coord_hl, color='C1', linewidth=2)

    elif G.coords.shape[1] == 2:
        sc = ax.scatter(G.coords[:, 0], G.coords[:, 1],
                        s=vertex_size, c=signal, zorder=2,
                        vmin=limits[0], vmax=limits[1])
        # Hollow circles around the highlighted vertices.
        ax.scatter(coords_hl[:, 0], coords_hl[:, 1],
                   s=2*vertex_size, zorder=3,
                   marker='o', c='None', edgecolors='C1', linewidths=2)

    elif G.coords.shape[1] == 3:
        sc = ax.scatter(G.coords[:, 0], G.coords[:, 1], G.coords[:, 2],
                        s=vertex_size, c=signal, zorder=2,
                        vmin=limits[0], vmax=limits[1])
        ax.scatter(coords_hl[:, 0], coords_hl[:, 1], coords_hl[:, 2],
                   s=2*vertex_size, zorder=3,
                   marker='o', c='None', edgecolors='C1', linewidths=2)
        # The camera settings are optional entries of G.plotting.
        try:
            ax.view_init(elev=G.plotting['elevation'],
                         azim=G.plotting['azimuth'])
            ax.dist = G.plotting['distance']
        except KeyError:
            pass

    if G.coords.ndim != 1 and colorbar:
        plt.colorbar(sc, ax=ax)
def _signal_to_colormap_index(signal, limits):
    # Map signal values linearly from [limits[0], limits[1]] onto the
    # colormap positions [1, 64] used by _qtg_plot_signal.
    span = limits[1] - limits[0]
    if span == 0:
        # Degenerate limits (e.g. a constant signal): avoid dividing by zero;
        # values equal to limits[0] then map to the first position.
        span = 1
    return 1 + 63 * (signal - limits[0]) / span


def _qtg_plot_signal(G, signal, show_edges, plot_name, vertex_size, limits):
    # pyqtgraph drawing of ``signal`` on G: a GraphicsWindow for 2D coords,
    # an OpenGL GLViewWidget for 3D coords, with vertex colors from a fixed
    # colormap spanning ``limits``. The window/widget is kept in a module
    # registry so that close_all() can close it.
    if G.coords.ndim != 2:
        # plot_signal accepts 1D coordinates, but only the matplotlib backend
        # can draw them; fail clearly instead of with an IndexError below.
        raise NotImplementedError('Signals on 1D coordinates can only be '
                                  'plotted with the matplotlib backend.')
    n_dims = G.coords.shape[1]

    if n_dims == 2:
        window = qtg.GraphicsWindow(plot_name)
        view = window.addViewBox()
    elif n_dims == 3:
        if not QtGui.QApplication.instance():
            QtGui.QApplication([])  # We want only one application.
        widget = gl.GLViewWidget()
        widget.opts['distance'] = 10
        widget.show()
        widget.setWindowTitle(plot_name)

    if show_edges:
        if G.is_directed():
            raise NotImplementedError
        if n_dims == 2:
            adj = _get_coords(G, edge_list=True)
            pen = tuple(np.array(G.plotting['edge_color']) * 255)
            g = qtg.GraphItem(pos=G.coords, adj=adj, symbolBrush=None,
                              symbolPen=None, pen=pen)
            view.addItem(g)
        elif n_dims == 3:
            x, y, z = _get_coords(G)
            pos = np.stack((x, y, z), axis=1)
            g = gl.GLLinePlotItem(pos=pos, mode='lines',
                                  color=G.plotting['edge_color'])
            widget.addItem(g)

    # Colormap defined on positions 1 to 64, from dark blue to dark red.
    pos = [1, 8, 24, 40, 56, 64]
    color = np.array([[0, 0, 143, 255], [0, 0, 255, 255], [0, 255, 255, 255],
                      [255, 255, 0, 255], [255, 0, 0, 255], [128, 0, 0, 255]])
    cmap = qtg.ColorMap(pos, color)

    signal = _signal_to_colormap_index(signal, limits)

    if n_dims == 2:
        gp = qtg.ScatterPlotItem(G.coords[:, 0],
                                 G.coords[:, 1],
                                 size=vertex_size/10,
                                 brush=cmap.map(signal, 'qcolor'))
        view.addItem(gp)
        _qtg_windows.append(window)
    elif n_dims == 3:
        gp = gl.GLScatterPlotItem(pos=G.coords,
                                  size=vertex_size/3,
                                  color=cmap.map(signal, 'float'))
        widget.addItem(gp)
        _qtg_widgets.append(widget)
def plot_spectrogram(G, node_idx=None):
    r"""
    Plot the spectrogram of the given graph.

    The spectrogram is computed with :func:`pygsp.features.compute_spectrogram`
    if ``G.spectr`` does not exist yet. Only the pyqtgraph backend is
    supported.

    Parameters
    ----------
    G : Graph
        Graph to analyse.
    node_idx : ndarray
        Order to sort the nodes in the spectrogram. If it selects only some
        nodes, only those rows are drawn.

    Raises
    ------
    NotImplementedError
        If pyqtgraph is not available.

    Examples
    --------
    >>> from pygsp import plotting
    >>> G = graphs.Ring(15)
    >>> plotting.plot_spectrogram(G)
    """
    from pygsp import features

    if not _QTG_IMPORT:
        raise NotImplementedError('You need pyqtgraph to plot the spectrogram '
                                  'at the moment. Please install and retry.')

    if not hasattr(G, 'spectr'):
        features.compute_spectrogram(G)

    M = G.spectr.shape[1]
    spectr = G.spectr[node_idx, :] if node_idx is not None else G.spectr
    spectr = np.atleast_2d(spectr)
    # Number of drawn node rows: G.N unless node_idx selects a subset.
    n_nodes = spectr.shape[0]
    spectr = np.ravel(spectr)
    min_spec, max_spec = spectr.min(), spectr.max()

    pos = np.array([0., 0.25, 0.5, 0.75, 1.])
    color = [[20, 133, 212, 255], [53, 42, 135, 255], [48, 174, 170, 255],
             [210, 184, 87, 255], [249, 251, 14, 255]]
    color = np.array(color, dtype=np.ubyte)
    cmap = qtg.ColorMap(pos, color)

    # Normalize to [0, 1]; a constant spectrogram maps to the first color
    # instead of producing 0/0.
    spec_range = max_spec - min_spec
    if spec_range == 0:
        spec_range = 1
    spectr = (spectr.astype(float) - min_spec) / spec_range

    w = qtg.GraphicsWindow()
    w.setWindowTitle("Spectrogram of {}".format(G.gtype))
    label = 'frequencies {}:{:.2f}:{:.2f}'.format(0, G.lmax/M, G.lmax)
    v = w.addPlot(labels={'bottom': 'nodes',
                          'left': label})
    v.setAspectLocked()

    # One square per (node, frequency) pair, in the row-major order of the
    # raveled spectrogram.
    spi = qtg.ScatterPlotItem(np.repeat(np.arange(n_nodes), M),
                              np.ravel(np.tile(np.arange(M), (1, n_nodes))),
                              pxMode=False,
                              symbol='s',
                              size=1,
                              brush=cmap.map(spectr, 'qcolor'))
    v.addItem(spi)

    _qtg_windows.append(w)
def _get_coords(G, edge_list=False):
    # Edge geometry of G, from G.get_edge_list().
    # edge_list=True: an (Ne, 2) array of (source, target) vertex indices.
    # 2D coords: [x, y], each a (2, Ne) array (row 0 sources, row 1 targets),
    #   so that ax.plot draws one segment per column.
    # 3D coords: [x, y, z], each flattened to source0, target0, source1, ...
    #   (consecutive pairs are the segments).
    # Other dimensions: None.
    sources, targets, _ = G.get_edge_list()
    if edge_list:
        return np.stack((sources, targets), axis=1)

    n_dims = G.coords.shape[1]
    # Shape (2, Ne, n_dims): endpoint, edge, coordinate.
    endpoints = np.stack((G.coords[sources], G.coords[targets]), axis=0)
    per_dim = [endpoints[:, :, dim] for dim in range(n_dims)]

    if n_dims == 2:
        return per_dim
    if n_dims == 3:
        return [segments.reshape(-1, order='F') for segments in per_dim]
def _handle_directed(G):
    # Return G itself if undirected; otherwise a new Graph built from the
    # symmetrized weight matrix that shares G's coordinates and plotting
    # parameters.
    # FIXME: plot edge direction. For now we just symmetrize the weight matrix.
    if G.is_directed():
        from pygsp import graphs
        undirected = graphs.Graph(utils.symmetrize(G.W))
        undirected.coords = G.coords
        undirected.plotting = G.plotting
        return undirected
    return G