11from collections import OrderedDict
2- from tqdm import tqdm
2+
33import matplotlib .pyplot as plt
44import numpy as np
55import pandas as pd
1010from keras .models import Model
1111from pygsp import graphs
1212from sklearn .cluster import spectral_clustering
13- from sklearn .datasets . samples_generator import make_blobs
13+ from sklearn .datasets import make_blobs
1414from sklearn .metrics .cluster import v_measure_score , homogeneity_score , completeness_score
1515from sklearn .neighbors import kneighbors_graph
16+ from spektral .layers import MinCutPool , DiffPool
1617from spektral .layers .convolutional import GraphConvSkip
1718from spektral .utils import init_logging
1819from spektral .utils .convolution import normalized_adjacency
19- from spektral .layers import MinCutPool , DiffPool
20+ from tqdm import tqdm
21+
2022from utils import citation
2123from utils .misc import sp_matrix_to_sp_tensor_value , product_dict
24+
2225np .random .seed (0 ) # for reproducibility
2326
2427PLOTS_ON = True
3639
3740# Tunables
3841tunables = OrderedDict ([
39- ('dataset' , ['cora' ]), # 'cora', 'citeseer', 'pubmed'
40- ('method' , ['mincut_pool' ]), # 'mincut_pool', 'diff_pool'
42+ ('dataset' , ['cora' ]), # 'cora', 'citeseer', 'pubmed', 'cloud', or 'synth '
43+ ('method' , ['mincut_pool' ]), # 'mincut_pool', 'diff_pool'
4144 ('H_' , [None ]),
4245 ('n_channels' , [16 ]),
4346 ('learning_rate' , [5e-4 ])
4447])
4548
4649N_RUNS = 1
47- dataset = None
4850df_out = None
4951for T in product_dict (tunables ):
5052 # Update params with current config
6365 A = sp .csr_matrix (A , dtype = np .float32 )
6466 n_clust = y .max () + 1
6567 elif P ['dataset' ] == 'cloud' :
66- G = graphs .Grid2d (N1 = 15 , N2 = 10 ) # Community(N=150, seed=0) #SwissRoll(N=400, seed=0) #Ring(N=100) #TwoMoons() #Cube(nb_pts=500) #Bunny()
68+ G = graphs .Grid2d (N1 = 15 , N2 = 10 ) # Community(N=150, seed=0) #SwissRoll(N=400, seed=0) #Ring(N=100) #TwoMoons() #Cube(nb_pts=500) #Bunny()
6769 X = G .coords .astype (np .float32 )
6870 A = G .W
6971 y = np .ones (X .shape [0 ]) # X[:,0] + X[:,1]
7072 n_clust = 5
7173 else :
72- if dataset != P ['dataset' ]:
73- dataset = P ['dataset' ]
74- A , X , _ , _ , _ , _ , _ , _ , y_ohe = citation .load_data (dataset )
75- y = np .argmax (y_ohe , axis = - 1 )
76- X = X .todense ()
77- n_clust = y .max () + 1
74+ A , X , _ , _ , _ , _ , _ , _ , y_ohe = citation .load_data (P ['dataset' ])
75+ y = np .argmax (y_ohe , axis = - 1 )
76+ X = X .todense ()
77+ n_clust = y .max () + 1
7878
7979 # Sort IDs
8080 if P ['dataset' ] != 'cloud' :
206206 plt .scatter (X [:, 0 ], X [:, 1 ], c = c )
207207 plt .title ('GNN-pool clustering' )
208208 if P ['dataset' ] == 'cloud' :
209- fig , ax = plt .subplots (1 ,1 , figsize = (3 ,3 ))
210- G .plot_signal (c , vertex_size = 30 , plot_name = '' , colorbar = False ,ax = ax )
209+ fig , ax = plt .subplots (1 , 1 , figsize = (3 , 3 ))
210+ G .plot_signal (c , vertex_size = 30 , plot_name = '' , colorbar = False , ax = ax )
211211 ax .set_xticks ([])
212212 ax .set_yticks ([])
213213 plt .tight_layout ()
214- plt .savefig ('logs/grid_mincut.pdf' , bbox_inches = 'tight' , pad_inches = 0 )
214+ plt .savefig ('logs/grid_mincut.pdf' , bbox_inches = 'tight' , pad_inches = 0 )
215215 plt .show ()
216216
217217 # Spectral clustering
220220 P ['complete_score_sc' ] = completeness_score (y , sc )
221221 P ['v_score_sc' ] = v_measure_score (y , sc )
222222
223- print ('Spectral Clust - HOMO: {:.2}, CS: {:2}, NMI: {:2}' .format (P ['homo_score_sc' ], P ['complete_score_sc' ], P ['v_score_sc' ]))
223+ print ('Spectral Clust - HOMO: {:.3f}, CS: {:.3f}, NMI: {:.3f}'
224+ .format (P ['homo_score_sc' ], P ['complete_score_sc' ], P ['v_score_sc' ]))
224225
225226 if df_out is None :
226227 df_out = pd .DataFrame ([P ])
234235 plt .title ('Spectral clustering' )
235236 plt .show ()
236237 if P ['dataset' ] == 'cloud' :
237- fig , ax = plt .subplots (1 ,1 , figsize = (3 ,3 ))
238- G .plot_signal (sc , vertex_size = 30 , plot_name = '' , colorbar = False ,ax = ax )
238+ fig , ax = plt .subplots (1 , 1 , figsize = (3 , 3 ))
239+ G .plot_signal (sc , vertex_size = 30 , plot_name = '' , colorbar = False , ax = ax )
239240 ax .set_xticks ([])
240241 ax .set_yticks ([])
241242 plt .tight_layout ()
242- plt .savefig ('logs/grid_spectral.pdf' , bbox_inches = 'tight' , pad_inches = 0 )
243- K .clear_session ()
243+ plt .savefig ('logs/grid_spectral.pdf' , bbox_inches = 'tight' , pad_inches = 0 )
244+ K .clear_session ()
0 commit comments