Ward clustering to learn a brain parcellation from rest fMRI

We use spatially-constrained Ward clustering to create a set of parcels. These parcels are particularly interesting for creating a ‘compressed’ representation of the data, in which the signal of each voxel in the fMRI images is replaced by the mean signal of the parcel it belongs to.

This parcellation may be useful in supervised learning; see for instance: A supervised clustering approach for fMRI-based inference of brain states, Michel et al., Pattern Recognition, 2011.

Python source code: plot_rest_clustering.py

### Load nyu_rest dataset #####################################################

import numpy as np
from nilearn import datasets
from nilearn import input_data
from nilearn.plotting.img_plotting import plot_roi, plot_epi
# Fetch the NYU resting-state dataset, restricted to a single subject
nyu_dataset = datasets.fetch_nyu_rest(n_subjects=1)

# print basic information on the dataset
print('First subject anatomical nifti image (3D) is at: %s' %
      nyu_dataset.anat_anon[0])
print('First subject functional nifti image (4D) is at: %s' %
      nyu_dataset.func[0])  # 4D data

# This is resting-state data: the background has not been removed yet,
# thus we need to use mask_strategy='epi' to compute the mask from the
# EPI images
nifti_masker = input_data.NiftiMasker(memory='nilearn_cache',
                                      mask_strategy='epi', memory_level=1,
                                      standardize=False)
func_filename = nyu_dataset.func[0]
# fmri_masked is a 2D array whose second axis indexes the in-mask voxels
fmri_masked = nifti_masker.fit_transform(func_filename)
# Boolean version of the mask computed by the masker. The builtin ``bool``
# is used rather than ``np.bool``: that alias was deprecated in NumPy 1.20
# and removed in NumPy 1.24, where it raises AttributeError.
# NOTE(review): ``get_data()`` is itself deprecated in recent nibabel
# releases -- consider ``get_fdata()`` when upgrading; confirm against the
# nibabel version in use.
mask = nifti_masker.mask_img_.get_data().astype(bool)

### Ward ######################################################################

# Build the voxel adjacency graph: it records which in-mask voxels are
# neighbours on the 3D grid, and is passed to the clustering below as its
# connectivity constraint
from sklearn.feature_extraction import image
shape = mask.shape
connectivity = image.grid_to_graph(
    n_x=shape[0],
    n_y=shape[1],
    n_z=shape[2],
    mask=mask,
)

# If you have scikit-learn older than 0.14, you need to import
# WardAgglomeration instead of FeatureAgglomeration
from sklearn.cluster import FeatureAgglomeration
import time

# Two successive fits, each timed. The first one is long; the second one
# asks for more clusters and should be faster, as both estimators share the
# same 'nilearn_cache' caching directory.
ward_runs = (
    (1000, "Ward agglomeration 1000 clusters: %.2fs"),
    (2000, "Ward agglomeration 2000 clusters: %.2fs"),
)
for n_parcels, report in ward_runs:
    start = time.time()
    ward = FeatureAgglomeration(
        n_clusters=n_parcels,
        connectivity=connectivity,
        linkage='ward',
        memory='nilearn_cache',
    )
    ward.fit(fmri_masked)
    print(report % (time.time() - start))
# At this point ``ward`` is the estimator fitted with 2000 clusters

### Show result ###############################################################

# Unmask data: map the per-voxel cluster labels back into image space.
# Labels are shifted by one to avoid a 0 label for the first cluster.
labels = ward.labels_ + 1
labels_img = nifti_masker.inverse_transform(labels)

from nilearn.image import mean_img
import matplotlib.pyplot as plt
# Mean of the functional image, used as the background of the ROI plot
mean_func_img = mean_img(func_filename)

# The cut coordinates of this first plot are reused for all the following
# plots (via first_plot.cut_coords), so that every figure shows the same
# slices
first_plot = plot_roi(labels_img, mean_func_img, title="Ward parcellation",
                      display_mode='xz')
# labels_img is a Nifti1Image object, it can be saved to file with the
# following code:
labels_img.to_filename('parcellation.nii')


# Display the original data (first row of the masked data)
plot_epi(nifti_masker.inverse_transform(fmri_masked[0]),
         cut_coords=first_plot.cut_coords,
         title='Original (%i voxels)' % fmri_masked.shape[1],
         display_mode='xz')

# Reduced data can be created by taking the parcel-level average:
# Note that, as many objects in the scikit-learn, the ward object exposes
# a transform method that modifies input features. Here it reduces their
# dimension
fmri_reduced = ward.transform(fmri_masked)

# Display the corresponding data compressed using the parcellation
fmri_compressed = ward.inverse_transform(fmri_reduced)
compressed_img = nifti_masker.inverse_transform(fmri_compressed[0])

# The parcel count in the title is read from the estimator instead of being
# hard-coded, so it stays correct if n_clusters is changed above
plot_epi(compressed_img, cut_coords=first_plot.cut_coords,
         title='Compressed representation (%i parcels)' % ward.n_clusters,
         display_mode='xz')

plt.show()

Total running time of the example: 0.00 seconds ( 0 minutes 0.00 seconds)