# -*- coding: utf-8 -*-
r"""
The :mod:`pygsp.plotting` module implements functionality to plot PyGSP objects
with a `pyqtgraph <https://www.pyqtgraph.org>`_ or `matplotlib
<https://matplotlib.org>`_ drawing backend (which can be controlled by the
:data:`BACKEND` constant or individually for each plotting call).
Most users won't use this module directly.
Graphs (from :mod:`pygsp.graphs`) are to be plotted with
:meth:`pygsp.graphs.Graph.plot` and
:meth:`pygsp.graphs.Graph.plot_spectrogram`.
Filters (from :mod:`pygsp.filters`) are to be plotted with
:meth:`pygsp.filters.Filter.plot`.

.. data:: BACKEND

    The default drawing backend to use if none is provided to the plotting
    functions. Should be either ``'matplotlib'`` or ``'pyqtgraph'``. In general
    pyqtgraph is better for interactive exploration while matplotlib is better
    at generating figures to be included in papers or elsewhere.
"""
from __future__ import division
import functools
import numpy as np
from pygsp import utils
# Logger for this module.
_logger = utils.build_logger(__name__)
# Drawing backend used when a plotting function is called with
# ``backend=None``: either 'matplotlib' or 'pyqtgraph'.
BACKEND = 'matplotlib'
# Windows opened by the plotting functions of this module, remembered so
# that close_all() can close them.
_qtg_widgets = []  # pyqtgraph widgets
_plt_figures = []  # matplotlib figures
def _import_plt():
    r"""Import matplotlib on demand, so it is only needed when plotting.

    Returns
    -------
    mpl : module
        :mod:`matplotlib`.
    plt : module
        :mod:`matplotlib.pyplot`.
    mplot3d : module
        :mod:`mpl_toolkits.mplot3d`, used for 3D axes and edge collections.

    Raises
    ------
    ImportError
        If any of the imports fails. The original exception is chained as
        the cause, so its full traceback is preserved.
    """
    try:
        import matplotlib as mpl
        from matplotlib import pyplot as plt
        from mpl_toolkits import mplot3d
    except Exception as e:
        # Catch broadly: importing pyplot may fail for reasons other than a
        # missing package (e.g. while setting up a GUI backend).
        raise ImportError('Cannot import matplotlib. Choose another backend '
                          'or try to install it with '
                          'pip (or conda) install matplotlib. '
                          'Original exception: {}'.format(e)) from e
    return mpl, plt, mplot3d
def _import_qtg():
    r"""Import pyqtgraph on demand, so it is only needed when plotting.

    Returns
    -------
    qtg : module
        :mod:`pyqtgraph`.
    gl : module
        :mod:`pyqtgraph.opengl`, used for 3D views.
    QtGui : module
        ``pyqtgraph.Qt.QtGui``.

    Raises
    ------
    ImportError
        If any of the imports fails. The original exception is chained as
        the cause, so its full traceback is preserved.
    """
    try:
        import pyqtgraph as qtg
        import pyqtgraph.opengl as gl
        from pyqtgraph.Qt import QtGui
    except Exception as e:
        # Catch broadly: a missing Qt binding or PyOpenGL may surface as an
        # error other than ImportError.
        raise ImportError('Cannot import pyqtgraph. Choose another backend '
                          'or try to install it with '
                          'pip (or conda) install pyqtgraph. You will also '
                          'need PyQt5 (or PySide) and PyOpenGL. '
                          'Original exception: {}'.format(e)) from e
    return qtg, gl, QtGui
def _plt_handle_figure(plot):
    r"""Handle the common work (creating an axis if not given, setting the
    title) of all matplotlib plot commands.

    The decorated function is called as ``plot(obj, ax=ax, **kwargs)`` and
    draws on ``ax``. The returned wrapper:

    * creates a new figure and axis (3D if ``obj.coords`` is an ``N x 3``
      array) when ``ax`` is missing or None, and shows that figure;
    * consumes the ``title`` keyword and sets it on the axis;
    * returns the ``(figure, axes)`` pair the drawing ended up on.
    """
    # Preserve documentation of plot.
    @functools.wraps(plot)
    def inner(obj, **kwargs):
        fig = None  # Only set when we create the figure ourselves.

        # Create a figure and an axis if none were passed.
        if kwargs.get('ax') is None:
            _, plt, _ = _import_plt()
            fig = plt.figure()
            # Remember it so that close_all() can close it.
            _plt_figures.append(fig)
            coords = getattr(obj, 'coords', None)
            if (coords is not None and coords.ndim == 2 and
                    coords.shape[1] == 3):
                kwargs['ax'] = fig.add_subplot(111, projection='3d')
            else:
                kwargs['ax'] = fig.add_subplot(111)

        title = kwargs.pop('title', None)
        plot(obj, **kwargs)
        ax = kwargs['ax']
        if title is not None:
            ax.set_title(title)

        # A caller-supplied axis belongs to a figure the caller manages:
        # only show the figures created here.
        if fig is not None:
            fig.show(warn=False)

        return ax.figure, ax
    return inner
def close_all():
    r"""Close all opened windows.

    Closes every pyqtgraph widget and every matplotlib figure created by the
    plotting functions of this module, then forgets about them.
    """
    global _qtg_widgets, _plt_figures

    for widget in _qtg_widgets:
        widget.close()
    _qtg_widgets = []

    if _plt_figures:
        # Import matplotlib once, and only when there is a figure to close.
        _, plt, _ = _import_plt()
        for fig in _plt_figures:
            plt.close(fig)
    _plt_figures = []
def show(*args, **kwargs):
    r"""Show created figures, alias to ``plt.show()``.

    By default, showing plots does not block the prompt.
    Calling this function will block execution.

    All positional and keyword arguments are forwarded unchanged to
    :func:`matplotlib.pyplot.show`.
    """
    _, plt, _ = _import_plt()
    plt.show(*args, **kwargs)
def close(*args, **kwargs):
    r"""Close a figure, alias to ``plt.close()``.

    All positional and keyword arguments are forwarded unchanged to
    :func:`matplotlib.pyplot.close`. Without arguments, this closes the
    current figure. The figure is not removed from the list of figures that
    :func:`close_all` closes.
    """
    _, plt, _ = _import_plt()
    plt.close(*args, **kwargs)
def _qtg_plot_graph(G, edges, vertex_size, title):
    r"""Draw a graph's vertices (and optionally its edges) with pyqtgraph.

    Parameters
    ----------
    G : Graph
        Graph to draw. ``G.coords`` must be an ``N x 2`` or ``N x 3`` array.
    edges : bool
        Whether to draw the edges, with color ``G.plotting['edge_color']``.
    vertex_size : scalar or array_like
        Vertex size. Divided by 10 (2D) or by 3 (3D) before being given to
        pyqtgraph.
    title : str
        Window title.

    Raises
    ------
    ValueError
        If the coordinates are neither 2D nor 3D.
    """
    qtg, gl, QtGui = _import_qtg()

    # Fail with a clear message rather than an IndexError (1D coordinates)
    # or an unbound ``widget`` (any other dimensionality).
    if G.coords.ndim != 2 or G.coords.shape[1] not in (2, 3):
        raise ValueError('The pyqtgraph backend can only plot graphs with '
                         '2D or 3D coordinates.')
    dim = G.coords.shape[1]

    if dim == 2:
        widget = qtg.GraphicsLayoutWidget()
        view = widget.addViewBox()
        view.setAspectLocked()
        if edges:
            # pyqtgraph expects 0-255 color components.
            pen = tuple(np.array(G.plotting['edge_color']) * 255)
        else:
            pen = None
        adj = _get_coords(G, edge_list=True)
        g = qtg.GraphItem(pos=G.coords, adj=adj, pen=pen,
                          size=vertex_size/10)
        view.addItem(g)
    else:
        if not QtGui.QApplication.instance():
            QtGui.QApplication([])  # We want only one application.
        widget = gl.GLViewWidget()
        widget.opts['distance'] = 10
        if edges:
            # Consecutive rows of pos are the two endpoints of one edge,
            # drawn as independent segments in 'lines' mode.
            x, y, z = _get_coords(G)
            pos = np.stack((x, y, z), axis=1)
            g = gl.GLLinePlotItem(pos=pos, mode='lines',
                                  color=G.plotting['edge_color'])
            widget.addItem(g)
        gp = gl.GLScatterPlotItem(pos=G.coords, size=vertex_size/3,
                                  color=G.plotting['vertex_color'])
        widget.addItem(gp)

    widget.setWindowTitle(title)
    widget.show()
    # Remember the window so that close_all() can close it.
    _qtg_widgets.append(widget)
def _plot_filter(filters, n, eigenvalues, sum, labels, title, ax, **kwargs):
    r"""Plot the spectral response of a filter bank.

    Parameters
    ----------
    n : int
        Number of points where the filters are evaluated.
    eigenvalues : boolean
        Whether to show the eigenvalues of the graph Laplacian.
        The eigenvalues should have been computed with
        :meth:`~pygsp.graphs.Graph.compute_fourier_basis`.
        By default, the eigenvalues are shown if they are available.
    sum : boolean
        Whether to plot the sum of the squared magnitudes of the filters.
        Default False if there is only one filter in the bank, True otherwise.
    labels : boolean
        Whether to label the filters.
        Default False if there is only one filter in the bank, True otherwise.
    title : str
        Title of the figure. Defaults to the ``repr`` of the filter bank.
    ax : :class:`matplotlib.axes.Axes`
        Axes where to draw the graph. Optional, created if not passed.
    kwargs : dict
        Additional parameters passed to the matplotlib plot function.
        Useful for example to change the linewidth, linestyle, or set a label.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
        The figure the plot belongs to. Only with the matplotlib backend.
    ax : :class:`matplotlib.axes.Axes`
        The axes the plot belongs to. Only with the matplotlib backend.

    Notes
    -----
    This function is only implemented with the matplotlib backend.

    Examples
    --------
    >>> import matplotlib
    >>> G = graphs.Logo()
    >>> mh = filters.MexicanHat(G)
    >>> fig, ax = mh.plot()

    """
    # Resolve every option left to its automatic (None) value.
    if eigenvalues is None:
        # Only mark the eigenvalues when they were already computed.
        eigenvalues = filters.G._e is not None
    several_filters = filters.n_filters > 1
    if sum is None:
        sum = several_filters
    if labels is None:
        labels = several_filters
    title = repr(filters) if title is None else title

    return _plt_plot_filter(filters, n=n, eigenvalues=eigenvalues, sum=sum,
                            labels=labels, title=title, ax=ax, **kwargs)
@_plt_handle_figure
def _plt_plot_filter(filters, n, eigenvalues, sum, labels, ax, **kwargs):
    r"""Draw the responses of a filter bank on ``ax`` with matplotlib.

    Wrapped by ``_plt_handle_figure``, which creates ``ax`` when it is None,
    sets the title, and returns ``(fig, ax)``.

    Parameters
    ----------
    filters : filter bank
        Evaluated with ``filters.evaluate`` on ``n`` points spanning
        ``[0, filters.G.lmax]``.
    n : int
        Number of evaluation points.
    eigenvalues : bool
        Whether to also evaluate at, and mark, the eigenvalues
        ``filters.G.e``.
    sum : bool
        Whether to draw the sum of the squared responses (in black).
    labels : bool
        Whether to label each line and draw a legend.
    ax : :class:`matplotlib.axes.Axes`
        Axes to draw on.
    kwargs : dict
        Extra matplotlib line properties. The response lines use
        ``alpha=0.5`` unless ``kwargs`` overrides it.
    """
    x = np.linspace(0, filters.G.lmax, n)
    params = dict(alpha=0.5)
    params.update(kwargs)

    if eigenvalues:
        # Evaluate the filter bank at the eigenvalues to avoid plotting
        # artifacts, for example when deltas are centered on the eigenvalues.
        x = np.sort(np.concatenate([x, filters.G.e]))

    # Columns of y are the filters: ax.plot draws one line per column, and
    # lines[i] below is matched with column i.
    y = filters.evaluate(x).T
    lines = ax.plot(x, y, **params)

    # TODO: plot highlighted eigenvalues

    if sum:
        # Drawn with the raw kwargs, i.e. without the default alpha=0.5 of
        # the response lines.
        line_sum, = ax.plot(x, np.sum(y**2, 1), 'k', **kwargs)

    if labels:
        for i, line in enumerate(lines):
            line.set_label(fr'$g_{{{i}}}(\lambda)$')
        if sum:
            line_sum.set_label(fr'$\sum_i g_i^2(\lambda)$')
        # NOTE(review): the legend is built before the eigenvalue
        # LineCollection below is added, so its 'eigenvalues' label does not
        # appear in it. Confirm whether that is intended.
        ax.legend()

    if eigenvalues:
        # One vertical segment per eigenvalue. x is the eigenvalue and y goes
        # from 0 to 1 in axes coordinates (x-axis transform), so the segments
        # span the full height whatever the data range.
        segs = np.empty((len(filters.G.e), 2, 2))
        segs[:, 0, 0] = segs[:, 1, 0] = filters.G.e
        segs[:, :, 1] = [0, 1]
        mpl, _, _ = _import_plt()
        ax.add_collection(mpl.collections.LineCollection(
            segs, transform=ax.get_xaxis_transform(), zorder=0,
            color=[0.9]*3, linewidth=1, label='eigenvalues')
        )

        # Plot dots where the evaluation matters.
        y = filters.evaluate(filters.G.e).T
        # Dots must not add legend entries of their own.
        params.pop('label', None)
        for i in range(y.shape[1]):
            # Give each dot series the color of its filter's line.
            params.update(color=lines[i].get_color())
            ax.plot(filters.G.e, y[:, i], '.', **params)
        if sum:
            params.update(color=line_sum.get_color())
            ax.plot(filters.G.e, np.sum(y**2, 1), '.', **params)

    ax.set_xlabel(r"laplacian's eigenvalues (graph frequencies) $\lambda$")
    ax.set_ylabel(r'filter response $g(\lambda)$')
def _plot_graph(G, vertex_color, vertex_size, highlight,
                edges, edge_color, edge_width,
                indices, colorbar, limits, ax, title, backend):
    r"""Plot a graph with signals as color or vertex size.

    Parameters
    ----------
    vertex_color : array_like or color
        Signal to plot as vertex color (length is the number of vertices).
        If None, vertex color is set to `graph.plotting['vertex_color']`.
        Alternatively, a color can be set in any format accepted by matplotlib.
        Each vertex color can be specified by an RGB(A) array of dimension
        `n_vertices` x 3 (or 4).
    vertex_size : array_like or int
        Signal to plot as vertex size (length is the number of vertices).
        Vertex size ranges from 0.5 to 2 times `graph.plotting['vertex_size']`.
        If None, vertex size is set to `graph.plotting['vertex_size']`.
        Alternatively, a size can be passed as an integer.
        The pyqtgraph backend only accepts an integer size.
    highlight : iterable
        List of indices of vertices to be highlighted.
        Useful for example to show where a filter was localized.
        Only available with the matplotlib backend.
    edges : bool
        Whether to draw edges in addition to vertices.
        Defaults to True if there are fewer than 10,000 edges to draw.
        Note that drawing many edges can be slow.
    edge_color : array_like or color
        Signal to plot as edge color (length is the number of edges).
        Edge color is given by `graph.plotting['edge_color']` and transparency
        ranges from 0.2 to 0.9.
        If None, edge color is set to `graph.plotting['edge_color']`.
        Alternatively, a color can be set in any format accepted by matplotlib.
        Each edge color can be specified by an RGB(A) array of dimension
        `n_edges` x 3 (or 4).
        Only available with the matplotlib backend.
    edge_width : array_like or int
        Signal to plot as edge width (length is the number of edges).
        Edge width ranges from 0.5 to 2 times `graph.plotting['edge_width']`.
        If None, edge width is set to `graph.plotting['edge_width']`.
        Alternatively, a width can be passed as an integer.
        Only available with the matplotlib backend.
    indices : bool
        Whether to print the node indices (in the adjacency / Laplacian matrix
        and signal vectors) on top of each node.
        Useful to locate a node of interest.
        Only available with the matplotlib backend.
    colorbar : bool
        Whether to plot a colorbar indicating the signal's amplitude.
        Only available with the matplotlib backend.
    limits : [vmin, vmax]
        Map colors from vmin to vmax.
        Defaults to the signal's minimum and maximum values, widened by 5% of
        their respective magnitudes (vmin lowered, vmax raised).
    ax : :class:`matplotlib.axes.Axes`
        Axes where to draw the graph. Optional, created if not passed.
        Only available with the matplotlib backend.
    title : str
        Title of the figure.
    backend: {'matplotlib', 'pyqtgraph', None}
        Defines the drawing backend to use.
        Defaults to :data:`pygsp.plotting.BACKEND`.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
        The figure the plot belongs to. Only with the matplotlib backend.
    ax : :class:`matplotlib.axes.Axes`
        The axes the plot belongs to. Only with the matplotlib backend.

    Raises
    ------
    AttributeError
        If the graph has no coordinates, if they are not in 1D, 2D or 3D
        space, or if there is not one coordinate per vertex.
    ValueError
        If a signal has the wrong length or shape, or the backend is unknown.

    Notes
    -----
    The orientation of directed edges is not shown. If edges exist in both
    directions, they will be drawn on top of each other.

    Examples
    --------
    >>> import matplotlib
    >>> graph = graphs.Sensor(20, seed=42)
    >>> graph.compute_fourier_basis(n_eigenvectors=4)
    >>> _, _, weights = graph.get_edge_list()
    >>> fig, ax = graph.plot(graph.U[:, 1], vertex_size=graph.dw,
    ...                      edge_color=weights)
    >>> graph.plotting['vertex_size'] = 300
    >>> graph.plotting['edge_width'] = 5
    >>> graph.plotting['edge_style'] = '--'
    >>> fig, ax = graph.plot(edge_width=weights, edge_color=(0, .8, .8, .5),
    ...                      vertex_color='black')
    >>> fig, ax = graph.plot(vertex_size=graph.dw, indices=True,
    ...                      highlight=[17, 3, 16], edges=False)

    """
    if not hasattr(G, 'coords') or G.coords is None:
        raise AttributeError('Graph has no coordinate set. '
                             'Please run G.set_coordinates() first.')
    check_2d_3d = (G.coords.ndim != 2) or (G.coords.shape[1] not in [2, 3])
    if G.coords.ndim != 1 and check_2d_3d:
        raise AttributeError('Coordinates should be in 1D, 2D or 3D space.')
    if G.coords.shape[0] != G.N:
        raise AttributeError('Graph needs G.N = {} coordinates.'.format(G.N))

    if backend is None:
        backend = BACKEND

    def check_shape(signal, name, length, many=False):
        # Signals are indexed by vertex (or edge) along their first axis.
        if (signal.ndim == 0) or (signal.shape[0] != length):
            txt = '{}: signal should have length {}.'
            txt = txt.format(name, length)
            raise ValueError(txt)
        if (not many) and (signal.ndim != 1):
            txt = '{}: can plot only one signal (not {}).'
            txt = txt.format(name, signal.shape[1])
            raise ValueError(txt)

    def normalize(x):
        """Scale values in [intercept, 1]. Return 0.5 if constant.

        Set intercept value in G.plotting["normalize_intercept"]
        with value in [0, 1], default is .25.
        """
        # np.ptp rather than the ndarray.ptp method, removed in NumPy 2.0.
        ptp = np.ptp(x)
        if ptp == 0:
            return np.full(x.shape, 0.5)
        intercept = G.plotting['normalize_intercept']
        return (1. - intercept) * (x - x.min()) / ptp + intercept

    def is_color(color):
        if backend == 'matplotlib':
            mpl, _, _ = _import_plt()
            if mpl.colors.is_color_like(color):
                return True  # single color
            try:
                return all(map(mpl.colors.is_color_like, color))  # color list
            except TypeError:
                return False  # e.g., color is an int
        else:
            return False  # No support for pyqtgraph (yet).

    if vertex_color is None:
        limits = [0, 0]
        colorbar = False
        if backend == 'matplotlib':
            vertex_color = (G.plotting['vertex_color'],)
    elif is_color(vertex_color):
        limits = [0, 0]
        colorbar = False
    else:
        vertex_color = np.asanyarray(vertex_color).squeeze()
        # Several signals can only be drawn (as curves) on 1D coordinates.
        check_shape(vertex_color, 'Vertex color', G.n_vertices,
                    many=(G.coords.ndim == 1))

    if vertex_size is None:
        vertex_size = G.plotting['vertex_size']
    elif not np.isscalar(vertex_size):
        vertex_size = np.asanyarray(vertex_size).squeeze()
        check_shape(vertex_size, 'Vertex size', G.n_vertices)
        # Squared because matplotlib's scatter size is an area.
        vertex_size = G.plotting['vertex_size'] * 4 * normalize(vertex_size)**2

    if edges is None:
        edges = G.Ne < 10e3

    if edge_color is None:
        edge_color = (G.plotting['edge_color'],)
    elif not is_color(edge_color):
        # Map the signal to the alpha channel of the default edge color.
        edge_color = np.asanyarray(edge_color).squeeze()
        check_shape(edge_color, 'Edge color', G.n_edges)
        edge_color = 0.9 * normalize(edge_color)
        edge_color = [
            np.tile(G.plotting['edge_color'][:3], [len(edge_color), 1]),
            edge_color[:, np.newaxis],
        ]
        edge_color = np.concatenate(edge_color, axis=1)

    if edge_width is None:
        edge_width = G.plotting['edge_width']
    elif not np.isscalar(edge_width):
        edge_width = np.array(edge_width).squeeze()
        check_shape(edge_width, 'Edge width', G.n_edges)
        edge_width = G.plotting['edge_width'] * 2 * normalize(edge_width)

    if limits is None:
        # Widen the signal's range by 5% of each bound's magnitude. Scaling
        # both bounds by 1.05 would instead shrink the range from below for
        # an all-positive signal (and from above for an all-negative one),
        # hiding its extreme values.
        vmin, vmax = vertex_color.min(), vertex_color.max()
        limits = [vmin - 0.05 * abs(vmin), vmax + 0.05 * abs(vmax)]

    if title is None:
        title = G.__repr__(limit=4)

    if backend == 'pyqtgraph':
        if vertex_color is None:
            _qtg_plot_graph(G, edges=edges, vertex_size=vertex_size,
                            title=title)
        else:
            _qtg_plot_signal(G, signal=vertex_color, vertex_size=vertex_size,
                             edges=edges, limits=limits, title=title)
    elif backend == 'matplotlib':
        return _plt_plot_graph(G, vertex_color=vertex_color,
                               vertex_size=vertex_size, highlight=highlight,
                               edges=edges, indices=indices, colorbar=colorbar,
                               edge_color=edge_color, edge_width=edge_width,
                               limits=limits, ax=ax, title=title)
    else:
        raise ValueError('Unknown backend {}.'.format(backend))
@_plt_handle_figure
def _plt_plot_graph(G, vertex_color, vertex_size, highlight,
                    edges, edge_color, edge_width,
                    indices, colorbar, limits, ax):
    r"""Draw ``G`` and its signals on ``ax`` with matplotlib.

    Wrapped by ``_plt_handle_figure``, which creates ``ax`` when it is None,
    sets the title, and returns ``(fig, ax)``. As called from
    ``_plot_graph``, the arguments are already resolved: ``limits`` is a
    ``[vmin, vmax]`` pair, and sizes, widths and colors are either single
    values or per-vertex / per-edge arrays.

    For 1D coordinates, ``vertex_color`` is drawn as curve(s) over the
    coordinates, highlighted vertices as vertical lines, and no edges are
    drawn. For 2D and 3D coordinates, edges (zorder 1) are drawn under the
    vertex scatter plot (zorder 2), and highlighted vertices are circled
    on top (zorder 3).
    """
    mpl, plt, mplot3d = _import_plt()

    if edges and (G.coords.ndim != 1):  # No edges for 1D plots.

        sources, targets, _ = G.get_edge_list()
        # From here, ``edges`` holds an (n_edges, 2, n_dims) array of
        # segment endpoints instead of a boolean.
        edges = [
            G.coords[sources],
            G.coords[targets],
        ]
        edges = np.stack(edges, axis=1)

        if G.coords.shape[1] == 2:
            LineCollection = mpl.collections.LineCollection
        elif G.coords.shape[1] == 3:
            LineCollection = mplot3d.art3d.Line3DCollection
        ax.add_collection(LineCollection(
            edges,
            linewidths=edge_width,
            colors=edge_color,
            linestyles=G.plotting['edge_style'],
            zorder=1,
        ))

    # Accept a single vertex index as well as an iterable of indices.
    try:
        iter(highlight)
    except TypeError:
        highlight = [highlight]
    coords_hl = G.coords[highlight]

    if G.coords.ndim == 1:
        ax.plot(G.coords, vertex_color, alpha=0.5)
        ax.set_ylim(limits)
        for coord_hl in coords_hl:
            ax.axvline(x=coord_hl, color=G.plotting['highlight_color'],
                       linewidth=2)

    else:
        sc = ax.scatter(*G.coords.T,
                        c=vertex_color, s=vertex_size,
                        marker='o', linewidths=0, alpha=0.5, zorder=2,
                        vmin=limits[0], vmax=limits[1])
        if np.isscalar(vertex_size):
            size_hl = vertex_size
        else:
            size_hl = vertex_size[highlight]
        # Hollow circles twice the vertex size around highlighted vertices.
        ax.scatter(*coords_hl.T,
                   s=2*size_hl, zorder=3,
                   marker='o', c='None',
                   edgecolors=G.plotting['highlight_color'], linewidths=2)

        if G.coords.shape[1] == 3:
            # The camera settings are optional plotting parameters.
            try:
                ax.view_init(elev=G.plotting['elevation'],
                             azim=G.plotting['azimuth'])
                ax.dist = G.plotting['distance']
            except KeyError:
                pass

    if G.coords.ndim != 1 and colorbar:
        plt.colorbar(sc, ax=ax)

    if indices:
        for node in range(G.N):
            ax.text(*tuple(G.coords[node]),  # accommodate 2D and 3D
                    s=node,
                    color='white',
                    horizontalalignment='center',
                    verticalalignment='center')
def _qtg_plot_signal(G, signal, edges, vertex_size, limits, title):
    r"""Draw a graph signal as vertex colors in a pyqtgraph window.

    Parameters
    ----------
    G : Graph
        Graph to draw. ``G.coords`` must be an ``N x 2`` or ``N x 3`` array.
    signal : array_like
        One value per vertex, mapped to a color.
    edges : bool
        Whether to draw the edges, with color ``G.plotting['edge_color']``.
    vertex_size : scalar or array_like
        Vertex size. Divided by 10 (2D) or by 3 (3D) before being given to
        pyqtgraph.
    limits : [vmin, vmax]
        Signal values mapped to the first and last colormap stops. When
        ``vmin == vmax``, every vertex gets the middle color.
    title : str
        Window title.

    Raises
    ------
    ValueError
        If the coordinates are neither 2D nor 3D.
    """
    qtg, gl, QtGui = _import_qtg()

    # Fail with a clear message rather than an IndexError (1D coordinates)
    # or an unbound ``widget`` (any other dimensionality).
    if G.coords.ndim != 2 or G.coords.shape[1] not in (2, 3):
        raise ValueError('The pyqtgraph backend can only plot graphs with '
                         '2D or 3D coordinates.')
    dim = G.coords.shape[1]

    if dim == 2:
        widget = qtg.GraphicsLayoutWidget()
        view = widget.addViewBox()
    else:
        if not QtGui.QApplication.instance():
            QtGui.QApplication([])  # We want only one application.
        widget = gl.GLViewWidget()
        widget.opts['distance'] = 10

    if edges:
        if dim == 2:
            adj = _get_coords(G, edge_list=True)
            # pyqtgraph expects 0-255 color components.
            pen = tuple(np.array(G.plotting['edge_color']) * 255)
            g = qtg.GraphItem(pos=G.coords, adj=adj, symbolBrush=None,
                              symbolPen=None, pen=pen)
            view.addItem(g)
        else:
            x, y, z = _get_coords(G)
            pos = np.stack((x, y, z), axis=1)
            g = gl.GLLinePlotItem(pos=pos, mode='lines',
                                  color=G.plotting['edge_color'])
            widget.addItem(g)

    pos = [1, 8, 24, 40, 56, 64]
    color = np.array([[0, 0, 143, 255], [0, 0, 255, 255], [0, 255, 255, 255],
                      [255, 255, 0, 255], [255, 0, 0, 255], [128, 0, 0, 255]])
    cmap = qtg.ColorMap(pos, color)

    # Map [limits[0], limits[1]] linearly onto the colormap stops [1, 64].
    # The divisor is the whole span (limits[1] - limits[0]), not limits[1].
    signal = np.asarray(signal, dtype=float)
    span = limits[1] - limits[0]
    if span == 0:
        # Constant signal: avoid a division by zero, use the middle color.
        signal = np.full(signal.shape, 1 + 63 / 2)
    else:
        signal = 1 + 63 * (signal - limits[0]) / span

    if dim == 2:
        gp = qtg.ScatterPlotItem(G.coords[:, 0],
                                 G.coords[:, 1],
                                 size=vertex_size/10,
                                 brush=cmap.map(signal, 'qcolor'))
        view.addItem(gp)
    else:
        gp = gl.GLScatterPlotItem(pos=G.coords,
                                  size=vertex_size/3,
                                  color=cmap.map(signal, 'float'))
        widget.addItem(gp)

    widget.setWindowTitle(title)
    widget.show()
    # Remember the window so that close_all() can close it.
    _qtg_widgets.append(widget)
def _plot_spectrogram(G, node_idx):
    r"""Plot the graph's spectrogram.

    Parameters
    ----------
    node_idx : ndarray
        Order to sort the nodes in the spectrogram.
        By default, does not reorder the nodes.

    Notes
    -----
    This function is only implemented for the pyqtgraph backend at the moment.
    If ``G.spectr`` does not exist yet, it is first computed with
    :func:`pygsp.features.compute_spectrogram`.
    Values are rescaled to the colormap's [0, 1] range. A constant
    spectrogram is drawn with the middle color.

    Examples
    --------
    >>> G = graphs.Ring(15)
    >>> G.plot_spectrogram()

    """
    from pygsp import features
    qtg, _, _ = _import_qtg()

    if not hasattr(G, 'spectr'):
        features.compute_spectrogram(G)

    M = G.spectr.shape[1]
    spectr = G.spectr[node_idx, :] if node_idx is not None else G.spectr
    spectr = np.ravel(spectr)
    min_spec, max_spec = spectr.min(), spectr.max()

    pos = np.array([0., 0.25, 0.5, 0.75, 1.])
    color = [[20, 133, 212, 255], [53, 42, 135, 255], [48, 174, 170, 255],
             [210, 184, 87, 255], [249, 251, 14, 255]]
    color = np.array(color, dtype=np.ubyte)
    cmap = qtg.ColorMap(pos, color)

    # Rescale to [0, 1], the range of the colormap stops.
    spectr = spectr.astype(float)
    if max_spec > min_spec:
        spectr = (spectr - min_spec) / (max_spec - min_spec)
    else:
        # Constant spectrogram: avoid 0 / 0 (NaN colors), use the middle.
        spectr = np.full(spectr.shape, 0.5)

    widget = qtg.GraphicsLayoutWidget()
    label = 'frequencies {}:{:.2f}:{:.2f}'.format(0, G.lmax/M, G.lmax)
    v = widget.addPlot(labels={'bottom': 'nodes', 'left': label})
    v.setAspectLocked()

    # One unit square per (node, frequency) cell. x is the node and y the
    # frequency bin, matching the row-major ravel of spectr above.
    spi = qtg.ScatterPlotItem(np.repeat(np.arange(G.N), M),
                              np.ravel(np.tile(np.arange(M), (1, G.N))),
                              pxMode=False,
                              symbol='s',
                              size=1,
                              brush=cmap.map(spectr, 'qcolor'))
    v.addItem(spi)

    widget.setWindowTitle("Spectrogram of {}".format(G.__repr__(limit=4)))
    widget.show()
    # Remember the window so that close_all() can close it.
    _qtg_widgets.append(widget)
def _get_coords(G, edge_list=False):
    r"""Gather the edge endpoints of ``G`` for the pyqtgraph drawing code.

    Parameters
    ----------
    G : Graph
        Graph whose edges (from ``G.get_edge_list()``) and ``coords`` are
        used.
    edge_list : bool
        If True, return the ``(n_edges, 2)`` array of (source, target)
        vertex indices instead of coordinates.

    Returns
    -------
    With 2D coordinates, a list of two ``(2, n_edges)`` arrays (x, then y)
    whose first row holds the sources' coordinate and second row the
    targets'. With 3D coordinates, a list of three 1D arrays (x, y, z) of
    length ``2 * n_edges`` interleaving sources and targets (source of edge
    0, target of edge 0, source of edge 1, ...). None for any other
    dimensionality.
    """
    src, dst, _ = G.get_edge_list()
    if edge_list:
        return np.column_stack((src, dst))

    n_dims = G.coords.shape[1]
    endpoints = []
    for axis_idx in range(n_dims):
        pair = np.vstack([G.coords[src, axis_idx], G.coords[dst, axis_idx]])
        endpoints.append(pair)

    if n_dims == 2:
        return endpoints
    if n_dims == 3:
        # Column-major flattening interleaves each source with its target.
        return [pair.ravel(order='F') for pair in endpoints]
    return None