## This file is part of MLPY.
## Fisher Discriminant Analysis.

## This is an implementation of Fisher Discriminant Analysis described in:
## 'An Improved Training Algorithm for Kernel Fisher Discriminants', S. Mika,
## A. Smola, B. Scholkopf, 2001.
    
## This code is written by Roberto Visintainer, <visintainer@fbk.eu> and
## Davide Albanese <albanese@fbk.eu>.
## (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

__all__ = ['Fda']

from numpy import *
from numpy.linalg import inv
import random as rnd


def dot3(a1, M, a2):
    """Return the bilinear product a1 * M * a2T.

    *a1* is first multiplied on the left of *M*; the resulting array is
    then contracted with *a2* over the last axis using ``numpy.inner``.
    """
    return inner(dot(a1, M), a2)



class Fda:
    """Fisher Discriminant Analysis.

    Implements the quadratic-program formulation of Kernel Fisher
    Discriminants (Mika, Smola, Scholkopf 2001) with the linear kernel
    ``K = x x^T``.
    """

    def __init__(self, C = 1):
        """
        Initialize Fda class.

        Input

          * *C* - [float] Regularization parameter
        """

        self.__C = C
        # Feature-weighting method used by weights(); 'cr' is the only
        # branch implemented there.
        self.__w = 'cr'

        self.__x     = None   # training data, set by compute()
        self.__y     = None   # training labels, set by compute()
        self.__xpred = None   # samples the test kernel is built against

        self.__a  = None      # expansion coefficients (alpha)
        self.__b  = None      # bias
        self.__K  = None      # training kernel matrix

    def __stdinvH(self, x, C):
        """Build matrix H and invert it.

        See eq. 4 at page 2.

        Matrix H:
        |-------------------------|
        |l(Val)       | oneK(Vet) |
        |-------------------------|
        |oneKT(Vet)   |  M(Mat)   |
        |-------------------------|

        Output

          * (K, KT, invH) - linear kernel x x^T, its transpose (the same
            matrix, since K is symmetric) and the inverse of H.
        """

        # Compute kernel matrix
        xT = x.transpose()
        K  = dot(x, xT)
        KT = K # (symmetric matrix)

        # Alloc H
        H = empty((K.shape[0] + 1, K.shape[0] + 1), dtype = float)

        # Compute oneK = 1T * K
        oneK = K.sum(axis = 0)

        # Build H
        # Compute M = (KT * K) + (C * P), with P = identity
        H[1:, 1:] = dot(KT, K) + identity(K.shape[1]) * C
        H[0, 1:] = oneK
        H[1:, 0] = oneK
        H[0, 0]  = x.shape[0]

        invH = inv(H)

        return (K, KT, invH)

    def __compute_a(self, x, y, KT, invH):
        """Compute the solution vector [b, alpha].

        See eq. 8, 9 at page 3.

        Input

          * *x* - unused; kept for signature compatibility
          * *y* - [1D numpy array] labels, 1 or -1
          * *KT* - transposed kernel matrix
          * *invH* - inverse of H as built by __stdinvH

        Output

          * *a* - [1D numpy array] a[0] is the bias b, a[1:] is alpha
        """

        lp = y[y ==  1].shape[0]
        ln = y[y == -1].shape[0]

        # Compute c, A+ and A-.
        # See eq. 4 at page 2.

        c = append((lp - ln), dot(KT, y))

        onep = zeros_like(y)
        onen = zeros_like(y)
        onep[y == 1 ] = 1
        onen[y == -1] = 1

        Ap = append(lp, dot(KT, onep))
        An = append(ln, dot(KT, onen))

        # Every quadratic form u^T invH v below equals dot(u, dot(invH, v)),
        # so only three matrix-vector products are needed; they are reused
        # both for the lambdas and for the final solution.
        invHAp = dot(invH, Ap)
        invHAn = dot(invH, An)
        invHc  = dot(invH, c)

        # Compute lambda
        # See eq. 9 at page 3.

        A = dot(Ap, invHAp)
        B = dot(Ap, invHAn)
        C = dot(An, invHAp)
        D = dot(An, invHAn)
        E = -(lp) + dot(c, invHAp)
        F = ln + dot(c, invHAn)

        lambdan = ( -F + ((C + B) * E / (2 * A)) ) / \
                  ( -D + ((C + B)**2  / (4 * A)) )
        lambdap = ( -E + (0.5 * (C + B) * lambdan) ) / -A

        # Compute a
        # See eq. 8 at page 3:
        #   a = invH (c - lambdap * Ap - lambdan * An), expanded by linearity.

        a = invHc - lambdap * invHAp - lambdan * invHAn

        return a

    def __standard(self):
        """Fit the model on the stored data; return (b, alpha)."""
        self.__K, KT, invH = self.__stdinvH(self.__x, self.__C)
        a                  = self.__compute_a(self.__x, self.__y, KT, invH)

        self.__xpred = self.__x

        # Return b, a
        return a[0], a[1:]

    def compute(self, x, y):
        """Compute fda model.

        Input

          * *x* - [2D numpy array float] (sample x feature)  training data
          * *y* - [1D numpy array integer] (two classes, 1 or -1) classes

        Output

          * 1

        Raises

          * ValueError - if *x* is not 2D, *y* does not have one label per
            sample, *y* contains labels other than 1 and -1, or one of
            the two classes is missing (the lambda equations then divide
            by zero).
        """

        x = asarray(x)
        y = asarray(y)

        if x.ndim != 2:
            raise ValueError("x must be a 2D array (samples x features)")
        if y.ndim != 1 or y.shape[0] != x.shape[0]:
            raise ValueError("y must be a 1D array with one label per sample")
        if not ((y == 1) | (y == -1)).all():
            raise ValueError("y must contain only the labels 1 and -1")
        if not (y == 1).any() or not (y == -1).any():
            raise ValueError("y must contain both classes (1 and -1)")

        self.__x = x
        self.__y = y

        self.__b, self.__a = self.__standard()

        return 1

    def predict(self, p):
        """Predict fda model on test point(s).

        Input

          * *p* - [1D or 2D numpy array float] test point(s)

        Output

          * *cl* - [integer or 1D numpy array integer] class(es) predicted
            (1, -1, or 0 when the real-valued prediction is exactly 0)
          * *self.realpred* - [float or 1D numpy array float] real valued prediction

        Raises

          * ValueError - if *p* is neither 1D nor 2D.
        """

        p = asarray(p)

        if p.ndim == 2:

            # Real prediction: f(p) = alpha^T (x p^T) + b
            K = dot(self.__xpred, p.transpose())
            self.realpred = dot(self.__a, K) + self.__b

            # Prediction
            pred = zeros(p.shape[0], dtype = int)
            pred[self.realpred > 0.0] = 1
            pred[self.realpred < 0.0] = -1

        elif p.ndim == 1:

            # Real prediction for a single point (column vector)
            K = dot(self.__xpred, p.reshape(-1, 1))
            self.realpred = (dot(self.__a, K) + self.__b)[0]

            # Prediction (integer, consistent with the 2D branch)
            pred = 0
            if self.realpred > 0.0:
                pred = 1
            elif self.realpred < 0.0:
                pred = -1

        else:
            raise ValueError("p must be a 1D or 2D array")

        return pred

    def weights (self, x, y):
        """
        Return feature weights.

        The model is first fitted with compute(x, y). For each feature i the
        linear kernel with that feature removed, K_i = K - x_i x_i^T, is
        compared with the target matrix T = y y^T (+1 for same-class pairs,
        -1 otherwise):

            w[i] = sqrt(<K_i, K_i>_F * <T, T>_F) / <K_i, T>_F

        i.e. the reciprocal of the kernel-target alignment of K_i.

        Input

          * *x* - [2D numpy array float] (sample x feature) training data
          * *y* - [1D numpy array integer] (two classes, 1 or -1) classes

        Output

          * *fw* - [1D numpy array float] feature weights
        """

        self.compute(x, y)

        if self.__w == 'cr':

            x = self.__x
            y = self.__y
            K = self.__K

            # Frobenius products are invariant to a common permutation of
            # the samples, so K, the feature columns and the target all stay
            # in the original sample order.
            target = outer(y, y)
            yy = (target * target).sum()

            w = empty(x.shape[1], dtype = float)

            for i in range(x.shape[1]):
                xi = x[:, i]
                # The linear kernel is additive over features, so removing
                # feature i subtracts its outer product.
                newK = K - outer(xi, xi)

                # Element-wise sums equal trace(dot(A, B)) for symmetric
                # A, B but cost O(n^2) instead of O(n^3).
                w[i] = sqrt((newK * newK).sum() * yy) / (newK * target).sum()

            return w
        
