Source code for pygsp.graphs.swissroll

# -*- coding: utf-8 -*-

from . import Graph
from pygsp.utils import distanz

import numpy as np
from math import sqrt, pi


class SwissRoll(Graph):
    r"""
    Create a swiss roll graph.

    Points are sampled on a 2-D spiral (``dim=2``) or on a rolled 3-D sheet
    (``dim=3``), centered and rescaled with :meth:`rescale_center`, and
    connected with Gaussian-kernel edge weights.

    Parameters
    ----------
    N : int
        Number of vertices (default = 400)
    a : int
        Start of the roll: the spiral angle lies in ``[a*pi, b*pi]``
        (default = 1)
    b : int
        End of the roll, see ``a`` (default = 4)
    dim : int
        Dimension of the embedding, 2 or 3 (default = 3)
    thresh : float
        Edge weights smaller than this value are set to zero
        (default = 1e-6)
    s : float
        Width (sigma) of the Gaussian kernel used for the edge weights,
        applied to the rescaled coordinates (default = sqrt(2./N))
    noise : bool
        Whether to add standard normal noise to the coordinates before
        rescaling (default = False)
    srtype : str
        Swiss roll type, possible arguments are 'uniform' or 'classic'
        (default = 'uniform'). 'classic' draws the angle uniformly in
        ``[a*pi, b*pi]``; 'uniform' draws it with a density proportional
        to the angle, which spreads points more evenly along the spiral.

    Raises
    ------
    ValueError
        If ``srtype`` is not 'uniform' or 'classic', or ``dim`` is not
        2 or 3.

    Examples
    --------
    >>> from pygsp import graphs
    >>> G = graphs.SwissRoll()

    """

    def __init__(self, N=400, a=1, b=4, dim=3, thresh=1e-6, s=None,
                 noise=False, srtype='uniform'):

        # Validate up front: an unknown value would otherwise leave ``tt``
        # or ``x`` unbound and fail later with an obscure NameError.
        if srtype not in ('uniform', 'classic'):
            raise ValueError("Unknown srtype '{}', expected 'uniform' or "
                             "'classic'.".format(srtype))
        if dim not in (2, 3):
            raise ValueError("Unsupported dim {}, expected 2 or 3."
                             .format(dim))

        if s is None:
            s = sqrt(2. / N)

        y1 = np.random.rand(N)
        # Drawn even for dim=2 so the random stream matches previous releases.
        y2 = np.random.rand(N)

        if srtype == 'uniform':
            # sqrt of a uniform variable on [a**2, b**2]: the density of the
            # angle grows linearly with it.
            tt = np.sqrt((b * b - a * a) * y1 + a * a)
        else:  # 'classic'
            tt = (b - a) * y1 + a
        tt *= pi

        if dim == 2:
            x = np.array((tt * np.cos(tt), tt * np.sin(tt)))
        else:  # dim == 3; 21 is the (pre-rescaling) width of the sheet
            x = np.array((tt * np.cos(tt), 21 * y2, tt * np.sin(tt)))

        if noise:
            x += np.random.randn(*x.shape)

        self.x = x
        self.dim = dim

        # The coordinates must be centered/rescaled *before* the distances
        # are computed from them (x has one row per coordinate).
        coords = self.rescale_center(x)
        # NOTE(review): assumes distanz returns the N x N matrix of pairwise
        # distances between the columns of coords — confirm in pygsp.utils.
        dist = distanz(coords)
        W = np.exp(-np.power(dist, 2) / (2. * s**2))
        W -= np.diag(np.diag(W))  # remove self-loops
        W[W < thresh] = 0  # drop negligible weights

        plotting = {'limits': np.array([-1, 1, -1, 1, -1, 1])}
        gtype = 'swiss roll {}'.format(srtype)
        super(SwissRoll, self).__init__(W=W, coords=coords.T,
                                        plotting=plotting, gtype=gtype)

    def rescale_center(self, x):
        r"""
        Rescaling the dataset.

        Center each row (coordinate) of ``x`` to zero mean, then divide the
        whole array by its largest entry.

        Parameters
        ----------
        x : ndarray
            2-D dataset to be rescaled, one row per coordinate and one
            column per point.

        Returns
        -------
        r : ndarray
            Rescaled dataset, same shape as ``x``. If every centered entry
            is zero (e.g. a single point), the centered array is returned
            unscaled instead of dividing by zero.

        Examples
        --------
        >>> import numpy as np
        >>> from pygsp import graphs
        >>> G = graphs.SwissRoll(N=10)
        >>> r = G.rescale_center(np.array([[0., 2.], [0., 4.]]))
        >>> float(r.max())
        1.0

        """
        # Subtract the per-row mean, broadcast over the columns.
        y = x - np.mean(x, axis=1, keepdims=True)
        # NOTE(review): this divides by the largest *signed* value, not the
        # largest magnitude, so entries below -1 are possible. It is kept
        # as-is because the kernel width in __init__ is tuned to this scale.
        c = np.amax(y)
        if c == 0:
            # Degenerate input (all points identical): avoid 0/0 -> NaN.
            return y
        return y / c