Source code for pygsp.graphs.swissroll

# -*- coding: utf-8 -*-

import numpy as np

from pygsp import utils
from . import Graph  # prevent circular import in Python < 3.5


class SwissRoll(Graph):
    r"""Sampled Swiss roll manifold.

    Parameters
    ----------
    N : int
        Number of vertices (default = 400)
    a : int
        Lower bound of the roll angle, in multiples of pi (default = 1)
    b : int
        Upper bound of the roll angle, in multiples of pi (default = 4)
    dim : int
        Dimension of the sampled points, either 2 or 3 (default = 3)
    thresh : float
        Edge weights strictly below this value are set to zero
        (default = 1e-6)
    s : float
        sigma, the width of the Gaussian kernel used to compute the weights
        (default = sqrt(2./N))
    noise : bool
        Whether to add noise or not (default = False)
    srtype : str
        Swiss roll Type, possible arguments are 'uniform' or 'classic'
        (default = 'uniform')
    seed : int
        Seed for the random number generator (for reproducible graphs).

    Raises
    ------
    ValueError
        If `srtype` is not 'uniform' or 'classic', or if `dim` is not 2 or 3.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.SwissRoll(N=200, seed=42)
    >>> fig = plt.figure()
    >>> ax1 = fig.add_subplot(121)
    >>> ax2 = fig.add_subplot(122, projection='3d')
    >>> _ = ax1.spy(G.W, markersize=1)
    >>> G.plot(ax=ax2)

    """

    def __init__(self, N=400, a=1, b=4, dim=3, thresh=1e-6, s=None,
                 noise=False, srtype='uniform', seed=None, **kwargs):

        # Fail early with a clear message: without these checks an unknown
        # value would leave `tt` or `x` unbound and surface later as an
        # UnboundLocalError.
        if srtype not in ('uniform', 'classic'):
            raise ValueError("srtype must be 'uniform' or 'classic', "
                             "got {!r}.".format(srtype))
        if dim not in (2, 3):
            raise ValueError('dim must be 2 or 3, got {!r}.'.format(dim))

        if s is None:
            s = np.sqrt(2. / N)

        # Both vectors are always drawn, in this order, so that a given seed
        # yields the same graph regardless of `dim`.
        rs = np.random.RandomState(seed)
        y1 = rs.rand(N)
        y2 = rs.rand(N)

        # Angle along the spiral, spanning [a * pi, b * pi).
        if srtype == 'uniform':
            tt = np.sqrt((b * b - a * a) * y1 + a * a)
        else:  # 'classic'
            tt = (b - a) * y1 + a
        tt *= np.pi

        if dim == 2:
            x = np.array((tt * np.cos(tt), tt * np.sin(tt)))
        else:  # dim == 3
            # The second coordinate is drawn uniformly in [0, 21).
            x = np.array((tt * np.cos(tt), 21 * y2, tt * np.sin(tt)))

        if noise:
            # Standard normal noise on every coordinate.
            x += rs.randn(*x.shape)

        self.x = x
        self.dim = dim

        # Gaussian kernel on the pairwise distances of the rescaled points.
        coords = utils.rescale_center(x)
        dist = utils.distanz(coords)
        W = np.exp(-np.power(dist, 2) / (2. * s**2))
        W -= np.diag(np.diag(W))  # remove self-loops
        W[W < thresh] = 0  # drop negligible weights

        plotting = {
            'vertex_size': 60,
            'limits': np.array([-1, 1, -1, 1, -1, 1]),
            'elevation': 15,
            'azimuth': -90,
            'distance': 7,
        }

        gtype = 'swiss roll {}'.format(srtype)

        super(SwissRoll, self).__init__(W=W, coords=coords.T,
                                        plotting=plotting,
                                        gtype=gtype, **kwargs)