Source code for pygsp.graphs.grid2d

# -*- coding: utf-8 -*-

import numpy as np
from scipy import sparse

from pygsp import utils
from . import Graph  # prevent circular import in Python < 3.5


class Grid2d(Graph):
    r"""2-dimensional grid graph.

    The ``N1 * N2`` vertices are numbered row by row, each row holding ``N2``
    consecutive vertices. Every vertex is linked with unit weight to the
    previous vertex of its row and to the vertex ``N2`` positions before it
    (same column, previous row).

    Parameters
    ----------
    N1 : int
        Number of vertices along the first dimension.
    N2 : int
        Number of vertices along the second dimension (default N1).

    Raises
    ------
    ValueError
        If `N1` or `N2` is smaller than 1.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> G = graphs.Grid2d(N1=5, N2=4)
    >>> fig, axes = plt.subplots(1, 2)
    >>> _ = axes[0].spy(G.W)
    >>> G.plot(ax=axes[1])

    """

    def __init__(self, N1=16, N2=None, **kwargs):

        if N2 is None:
            N2 = N1

        # Fail early with a clear message instead of an obscure NumPy/SciPy
        # error (or a meaningless matrix) further down.
        if N1 < 1 or N2 < 1:
            raise ValueError('N1 and N2 must be at least 1, '
                             'got N1={} and N2={}.'.format(N1, N2))

        N = N1 * N2

        # Filling up the weight matrix diagonal by diagonal is faster than
        # looping through all the grid points. Only the diagonals that
        # actually exist are emitted: with N2 == 1 both would sit at offset
        # -1, and sparse.diags rejects duplicate offsets.
        diagonals = []
        offsets = []

        if N2 > 1:
            # Links between consecutive vertices of the same row. The entry
            # joining the last vertex of a row to the first vertex of the
            # next row is zeroed so that rows do not wrap around.
            within_row = np.ones(N - 1)
            within_row[(N2 - 1)::N2] = 0
            diagonals.append(within_row)
            offsets.append(-1)

        if N1 > 1:
            # Links between a vertex and the one N2 positions earlier, i.e.
            # the same column in the previous row.
            diagonals.append(np.ones(N - N2))
            offsets.append(-N2)

        if diagonals:
            W = sparse.diags(diagonals=diagonals,
                             offsets=offsets,
                             shape=(N, N),
                             format='csr',
                             dtype='float')
        else:
            # 1x1 grid: a single vertex without any edge.
            W = sparse.csr_matrix((N, N), dtype='float')

        # Only the lower triangle is filled above; presumably this mirrors
        # it to obtain a symmetric weight matrix (see utils.symmetrize).
        W = utils.symmetrize(W, method='tril')

        # x repeats 0, 1/N2, ..., (N2-1)/N2 for each of the N1 rows.
        x = np.kron(np.ones((N1, 1)), (np.arange(N2)/float(N2)).reshape(N2, 1))
        # y takes the values 0, 1/N1, ..., (N1-1)/N1, each repeated N2 times
        # and sorted in decreasing order, so the first row gets the largest y.
        y = np.kron(np.ones((N2, 1)), np.arange(N1)/float(N1)).reshape(N, 1)
        y = np.sort(y, axis=0)[::-1]
        coords = np.concatenate((x, y), axis=1)

        # NOTE(review): assuming the layout is [xmin, xmax, ymin, ymax], the
        # lower y limit (1/N1) lies above the smallest y coordinate (0) --
        # confirm whether this is intended. Values kept as in the original.
        plotting = {"limits": np.array([-1. / N2, 1 + 1. / N2,
                                        1. / N1, 1 + 1. / N1])}

        super(Grid2d, self).__init__(W=W, gtype='2d-grid', coords=coords,
                                     plotting=plotting, **kwargs)