%   This code tests the performance of EnSC (Elastic Net Subspace
%   Clustering) on clustering the Coil-100 object image, PIE face image,
%   MNIST handwritten digit image, and CovType databases. The code
%   generates the results in Table 2 of the paper:

%   Chong You, Chun-guang Li, Daniel Robinson, Rene Vidal,
%   "Oracle Based Active Set Algorithm for Scalable Elastic Net Subspace
%   Clustering", CVPR 2016.

% Instructions for running the code:
% - Download code for computing clustering accuracy. Go to
% http://www.cad.zju.edu.cn/home/dengcai/Data/Clustering.html and download
% bestMap.m and Hungarian.m. Alternatively, you can use your own accuracy
% function by redefining evalAccuracy.m.
% - Download databases.
%   = Coil-100: we use the data provided at
%   http://www.cad.zju.edu.cn/home/dengcai/Data/MLData.html
%   = PIE: (not available)
%   = MNIST: 
%     Download the MNIST training/test image/label files from
%     http://yann.lecun.com/exdb/mnist/, decompress them, and put them in
%     folder MNIST/.
%     Download files for reading data (loadMNISTImages.m and 
%     loadMNISTLabels.m) from
%     http://ufldl.stanford.edu/wiki/index.php/MATLAB_Modules
%     Download scattering transform package ScatNet (v0.2) from
%     http://www.di.ens.fr/data/software/
%     and install.
%   = CovType: download from
%     https://kdd.ics.uci.edu/databases/covertype/covertype.html
% - Run the script. Set the variable "databaseName" below to choose the
%   database ('Coil100', 'PIE', 'MNIST' or 'CovType').

% Copyright Chong You @ Johns Hopkins University, 2016
% chong.you1987@gmail.com

% Put the Tools/ folder on the path (presumably it holds the solver,
% normalization and evaluation helpers called below).
addpath Tools
%% Settings
% Database to run; matched case-insensitively against
% 'Coil100', 'PIE', 'MNIST' and 'CovType' in the loading section.
databaseName = 'Coil100';

%% Load data and set parameters
% Each branch below defines:
%   X                    - data matrix with one sample per column
%   s                    - ground-truth label of every sample
%   nu0, lambda, Nsample - parameters passed to ORGEN_mat_func below
% databaseName is compared case-insensitively (strcmpi).
if strcmpi(databaseName, 'coil100')
    addpath('Coil100')
    load COIL100.mat fea gnd
    % Transpose so samples become columns (assumes fea stores one sample
    % per row -- TODO confirm against the data file).
    X = double(fea');
    s = gnd;
    
    nu0 = 3;
    lambda = 0.95;
    Nsample = 100;
elseif strcmpi(databaseName, 'PIE')
    addpath('PIE')
    load PIE.mat fea gnd
    X = double(fea');   % same layout handling as the Coil100 branch
    s = gnd;
    
    nu0 = 200;
    lambda = 0.1;
    Nsample = 800;
elseif strcmpi(databaseName, 'MNIST')
    addpath('MNIST')
    % Load the features unless the variables used below are already in the
    % workspace. The check must be on MNIST_SC_DATA / MNIST_LABEL, not on the
    % raw images MNIST_DATA: the fallback path leaves MNIST_DATA in the
    % workspace while MNIST_SC_DATA is cleared after use, so checking
    % MNIST_DATA would skip loading on a re-run and fail below.
    if ~exist('MNIST_SC_DATA', 'var') || ~exist('MNIST_LABEL', 'var')
        try
            % MNIST_SC_DATA is a D by N matrix. Each column contains a feature 
            % vector of a digit image and N = 70,000 (i.e. contains both training 
            % and testing)
            % MNIST_LABEL is a 1 by N vector. Each entry is the label for the
            % corresponding column in MNIST_SC_DATA.
            load MNIST_SC.mat MNIST_SC_DATA MNIST_LABEL;
        catch
            % No usable cache: build the features from the raw (decompressed)
            % MNIST files, training set followed by test set, and cache them.
            fprintf('load MNIST...\n')
            MNIST_DATA = [loadMNISTImages('train-images-idx3-ubyte'), ...
                          loadMNISTImages('t10k-images-idx3-ubyte')];
            MNIST_LABEL = [loadMNISTLabels('train-labels-idx1-ubyte'); ...
                           loadMNISTLabels('t10k-labels-idx1-ubyte')];
            fprintf('scattering transform on MNIST...\n')
            MNIST_SC_DATA = SCofDigits(MNIST_DATA);
            save MNIST_SC.mat MNIST_SC_DATA MNIST_LABEL;
        end
    end
    % PCA projection of the features (presumably to 500 dimensions).
    X = dimReduction_PCA(MNIST_SC_DATA, 500);
    clear MNIST_SC_DATA;   % free memory; only X is needed from here on
    s = MNIST_LABEL;
    
    nu0 = 120;
    lambda = 0.95;
    Nsample = 600;
elseif strcmpi(databaseName, 'covtype')
    addpath('CovType')
    DATA = csvread('covtype.data');
    % One sample per row: columns 1-54 are features, column 55 is the label.
    s = DATA(:, 55);
    X = DATA(:, 1:54);
    % Normalize each feature column (cnormalize presumably scales columns
    % to unit norm -- TODO confirm), then transpose so samples are columns.
    X = cnormalize(X)';
    
    nu0 = 50; 
    lambda = 0.95;
    Nsample = 500;
else
    error('Unknown databaseName ''%s''; expected ''Coil100'', ''PIE'', ''MNIST'' or ''CovType''.', ...
          databaseName);
end
nCluster = length(unique(s));   % number of distinct ground-truth labels
%% Clustering
% Everything between tic and toc is counted in the reported running time.
tic;
% Normalize the columns (samples) of X with the project helper.
X = cnormalize_inplace(X);
% Elastic-net solver handle for ORGEN_mat_func: the mixing weight lambda
% and the parameter nu are mapped onto rfss's two penalty arguments,
% lambda/nu and (1-lambda)/nu.
EN_solver = @(X, y, lambda, nu) rfss(X, y, lambda / nu, (1 - lambda) / nu);
% Alternative solver:
% EN_solver = @(X, y, lambda, nu) CompSens_EN_Homotopy(X, y, lambda, nu);
R = ORGEN_mat_func(X, EN_solver, ...
                   'nu0', nu0, ...
                   'nu_method', 'nonzero', ...
                   'lambda', lambda, ...
                   'Nsample', Nsample, ...
                   'maxiter', 2, ...
                   'outflag', true);
% Zero the diagonal of R via linear indices 1, N+2, 2N+3, ...
% (assumes R is N-by-N with N = number of samples).
N = length(s);
R(1:N+1:end) = 0;
% Symmetric affinity |R| + |R|'.
A = abs(R);
A = A + A';
groups = SpectralClustering(A, nCluster, 'Eig_Solver', 'eigs');
time = toc;

% Evaluation: metrics on the coefficients R and affinity A against the
% ground-truth labels s, plus clustering accuracy of the predicted groups.
perc = evalSSR_perc(R, s);
ssr  = evalSSR_error(R, s);
conn = evalConn(A, s);
accr = evalAccuracy(s, groups);
% Print all metrics and the clustering time on a single line.
dataValue = [perc, ssr, conn, accr, time];
dataformat = 'perc = %f, ssr = %f, conn = %f, accr = %f, time = %f\n';
fprintf(dataformat, dataValue);

