function x = CompSens_EN_Homotopy(A, b, lambda, gamma, x0, max_iter)
%COMPSENS_EN_HOMOTOPY Solve the elastic net problem by a homotopy method.
% This code solves the following problem using homotopy:
%   min_x lambda ||x||_1 + (1-lambda)/2 ||x||_2^2 + gamma/2 ||Ax - b||_2^2
%
% Usage 1:
%   x = CompSens_EN_Homotopy(A, b, lambda, gamma)
% If 'x0' is not specified (or is empty), this is the canonical homotopy
% method started from x = 0.
%
% Usage 2:
%   x = CompSens_EN_Homotopy(A, b, lambda, gamma, x0)
% The code is written in a general way such that it can take an initial
% value x0 supported on the set 1:P for some P, where x_P := x0(1:P) is a
% solution to
%   min_{x_P} lambda ||x_P||_1 + (1-lambda)/2 ||x_P||_2^2 + gamma/2 ||A_P * x_P - b||_2^2,
% and A_P is A(:, 1:P). P is taken as the index just before the first zero
% entry of x0; entries of x0 after that first zero are ignored. x0 may be a
% row or a column vector.
% The algorithm then performs homotopy on epsilon in the following problem:
%   min_{x_P, x_Q} lambda ||x_P||_1 + (1-lambda)/2 ||x_P||_2^2 + ...
%                  epsilon ||x_Q||_1 + (1-lambda)/2 ||x_Q||_2^2 + ...
%                  gamma/2 ||A_P * x_P + A_Q * x_Q - b||_2^2,
% decreasing epsilon until it reaches lambda. Here Q is the complement of P
% and A_Q = A(:, P+1:end). If P is 0, the algorithm falls back to the
% traditional homotopy for the elastic net. If x0 has no zero entry (Q is
% empty), x0 is returned as a column vector.
%
% Optional argument:
%   max_iter - maximum number of homotopy steps (default 1000). An error
%              with identifier 'CompSens_EN_Homotopy:notConverged' is raised
%              if epsilon = lambda is not reached within max_iter steps.
%
% Output:
%   x - N-by-1 solution, where N = size(A, 2).
%
% Requires the helper functions cholinsert and choldelete.
%
% This is designed to be used in conjunction with ORGEN, which is
% described in
%   Chong You, Chun-guang Li, Daniel Robinson, Rene Vidal,
%   "Oracle Based Active Set Algorithm for Scalable Elastic Net Subspace
%   Clustering", CVPR 2016.

% Copyright Chong You @ Johns Hopkins University, 2016
% chong.you1987@gmail.com

N = size(A, 2);

if nargin < 5 || isempty(x0)
    x0 = zeros(N, 1);
end
if nargin < 6 || isempty(max_iter)
    max_iter = 10^3;
end
% Force a column so that the vertical concatenations and products below work
% for row-vector input as well.
x0 = x0(:);

% P = 1:pivot is the leading nonzero block of x0.
pivot = find(x0 == 0, 1, 'first') - 1;
if isempty(pivot)
    % x0 has no zero entry: Q is empty, and by assumption x0 solves the problem.
    x = x0;
    return;
end

% delta = gamma * (A x - b); A' * delta is the gradient of the data term.
delta = gamma * (A(:, 1:pivot) * x0(1:pivot) - b);

% For any epsilon >= max|A_Q' * delta|, x_Q = 0 is optimal; below this value
% coordinate jstar (within Q) becomes active.
[epsilon, jstar] = max( abs(A(:, pivot+1:end)' * delta) );
if epsilon <= lambda
    % x_Q = 0 is already optimal at epsilon = lambda; stepping along the path
    % here would be a negative step and produce a wrong-signed coefficient.
    x = [x0(1:pivot); zeros(N - pivot, 1)];
    return;
end

% initialize
% R_P is the upper Cholesky factor (chol default) of A_P'A_P + (1-lambda)/gamma I.
R_P = chol( A(:, 1:pivot)' * A(:, 1:pivot) + (1-lambda)/gamma * eye(pivot) );
S = [1:pivot, jstar + pivot];           % active set (column indices of A)

Pc = false(N, 1);                       % Pc(j): j in P and currently inactive
Qc = [false(pivot, 1); true(N - pivot, 1)]; % Qc(j): j in Q and currently inactive
Qc(jstar + pivot) = false;

% z is 0 for entries from P and -sign(A_j' * delta) (taken when j entered) for
% entries from Q.
z = [zeros(pivot, 1); -sign(A(:, jstar + pivot)' * delta)];
x_S = [x0(1:pivot); 0];
% NOTE(review): cholinsert is defined elsewhere; presumably it returns the
% factor for [A(:, 1:pivot), A(:, jstar + pivot)] with the same ridge term.
R_S = cholinsert(R_P, A(:, jstar + pivot), A(:, 1:pivot), (1-lambda)/gamma);

converged = false;
for ii = 1:max_iter
    % Along the current linear segment, decreasing epsilon by d gives
    %   x_S   <- x_S   + u_S * d,  u_S = (A_S'A_S + (1-lambda)/gamma I) \ z / gamma
    %   delta <- delta + v_S * d,  v_S = gamma * A_S * u_S
    u_S = R_S \ (R_S' \ z);
    v_S = A(:, S) * u_S;
    u_S = u_S / gamma;

    % find the next critical point (smallest step d > 0)
    % - minus: an active coefficient reaches zero and leaves S
    deps_minus_vec = (-x_S ./ u_S);
    deps_minus_vec( deps_minus_vec <= eps ) = inf;
    [deps_minus, deps_minus_index] = min(deps_minus_vec);
    % Correlations along the segment: A' * delta(d) = coh_bias + coh_slope * d
    coh_bias = A' * delta;
    coh_slope = A' * v_S;
    % - P plus: an inactive j in P reaches |A_j' * delta| = lambda
    if ~any(Pc) % all points in P are active
        deps_Pplus = inf;
    else
        hit_sign = sign(coh_slope(Pc));
        deps_Pplus_vec = (hit_sign * lambda - coh_bias(Pc)) ./ coh_slope(Pc);
        [deps_Pplus, deps_Pplus_index] = min(deps_Pplus_vec);
    end
    % - Q plus: an inactive j in Q reaches |A_j' * delta| = epsilon - d.
    %   The sign of the correlation at d = epsilon tells which boundary
    %   (+ or -) the line crosses first.
    hit_sign = sign(coh_bias + coh_slope * epsilon);

    deps_Qplus_vec = (hit_sign * epsilon - coh_bias) ./ (coh_slope + hit_sign);
    [deps_Qplus, deps_Qplus_index] = min(deps_Qplus_vec(Qc));
    hit_sign = hit_sign(Qc);
    % Earliest event; an empty deps_Qplus (Qc all false) drops out of the list.
    [deps, deps_index] = min([deps_minus, deps_Pplus, deps_Qplus]);
    % Stop once the next event lies beyond epsilon = lambda.
    if epsilon - deps < lambda
        converged = true;
        break;
    end
    epsilon = epsilon - deps;
    x_S = x_S + u_S * deps;

    if deps_index == 1 % remove a point
        index = S(deps_minus_index) ;
        if index <= pivot
            Pc(index) = true;
        else
            Qc(index) = true;
        end
        S(deps_minus_index) = [];
        z(deps_minus_index) = [];
        R_S = choldelete(R_S, deps_minus_index);
        x_S(deps_minus_index) = [];
    else % add point
        if deps_index == 2 % add to P_S
            % index of the deps_Pplus_index-th true entry of Pc
            index = find(Pc, deps_Pplus_index);
            index = index(end);
            Pc(index) = false;
            z = [z; 0];
        else
            % index of the deps_Qplus_index-th true entry of Qc
            index = find(Qc, deps_Qplus_index);
            index = index(end);
            Qc(index) = false;
            z = [z; -hit_sign(deps_Qplus_index)];
        end
        R_S = cholinsert(R_S, A(:, index), A(:, S), (1-lambda)/gamma);
        S = [S, index];
        x_S = [x_S; 0];
    end
    delta = delta + v_S * deps;

%     fprintf('Support size: %d, changed entry: %d, epsilon: %f\n', length(S), index, epsilon)
end
% fprintf('Iterations: %d, final tau: %f\n', ii, epsilon)
if ~converged
    error('CompSens_EN_Homotopy:notConverged', ...
        'Too many iterations (%d), not converged', max_iter);
end

% Follow the final segment from the current epsilon down to lambda.
x = zeros(N, 1);
x(S) = x_S + u_S * (epsilon - lambda);


        
