/*******************************************************************************
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 * 
 * Contact Info:
 * 	Bruce Donald
 * 	Duke University
 * 	Department of Computer Science
 * 	Levine Science Research Center (LSRC)
 * 	Durham
 * 	NC 27708-0129 
 * 	USA
 * 	brd@cs.duke.edu
 * 
 * Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
 * 
 * <signature of Bruce Donald>, April 2011
 * Bruce Donald, Professor of Computer Science
 ******************************************************************************/


package edu.duke.donaldLab.share.nmr;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;

import edu.duke.donaldLab.share.geom.Vector3;
import edu.duke.donaldLab.share.io.Logging;
import edu.duke.donaldLab.share.math.EigPair;
import edu.duke.donaldLab.share.math.Matrix3;
import edu.duke.donaldLab.share.perf.LoggingMessageListener;
import edu.duke.donaldLab.share.perf.MessageListener;
import edu.duke.donaldLab.share.perf.Progress;
import edu.duke.donaldLab.share.protein.Atom;
import edu.duke.donaldLab.share.protein.AtomAddressInternal;
import edu.duke.donaldLab.share.protein.Protein;
import edu.duke.donaldLab.share.protein.Subunit;

/**
 * An RDC alignment tensor represented by the five independent elements of a
 * traceless, symmetric Saupe matrix (Sxx is implied as -Syy - Szz), together
 * with its eigen-decomposition.
 */
public class AlignmentTensor
{
	/**************************
	 *   Data Members
	 **************************/
	
	private static final Logger m_log = Logging.getLog( AlignmentTensor.class );
	
	/** Number of independent Saupe matrix elements that must be solved for. */
	private static final int NumTensorElements = 5;
	
	private double m_Sxy;
	private double m_Sxz;
	private double m_Syy;
	private double m_Syz;
	private double m_Szz;
	
	// eigenvalue/eigenvector pairs, indexed by AlignmentTensorAxis ordinal
	private EigPair[] m_eigs;
	
	
	/**************************
	 *   Constructors
	 **************************/
	
	/**
	 * Creates a tensor from the five independent Saupe matrix elements and
	 * eigen-decomposes the resulting 3x3 Saupe matrix.
	 * 
	 * The eigen pairs are assigned to the X, Y and Z axes in the order they
	 * come out of a PriorityQueue, i.e. by EigPair's natural ordering.
	 * NOTE(review): the original intent is increasing eigenvalue magnitude
	 * (so Z gets the largest); confirm against EigPair.compareTo.
	 * 
	 * @param Sxy Saupe matrix element S(x,y)
	 * @param Sxz Saupe matrix element S(x,z)
	 * @param Syy Saupe matrix element S(y,y)
	 * @param Syz Saupe matrix element S(y,z)
	 * @param Szz Saupe matrix element S(z,z)
	 */
	public AlignmentTensor( double Sxy, double Sxz, double Syy, double Syz, double Szz )
	{
		m_Sxy = Sxy;
		m_Sxz = Sxz;
		m_Syy = Syy;
		m_Syz = Syz;
		m_Szz = Szz;
		
		// construct the saupe matrix
		Matrix saupe = new Matrix( 3, 3 );
		getSaupe().toJama( saupe );
		
		// order the eigen pairs by EigPair's natural ordering (smallest first)
		EigenvalueDecomposition eig = saupe.eig();
		PriorityQueue<EigPair> q = new PriorityQueue<EigPair>( 3 );
		q.add( new EigPair( eig, 0 ) );
		q.add( new EigPair( eig, 1 ) );
		q.add( new EigPair( eig, 2 ) );
		m_eigs = new EigPair[3];
		m_eigs[AlignmentTensorAxis.X.ordinal()] = q.poll();
		m_eigs[AlignmentTensorAxis.Y.ordinal()] = q.poll();
		m_eigs[AlignmentTensorAxis.Z.ordinal()] = q.poll();
	}
	

	/**************************
	 *   Accessors
	 **************************/
	
	/** @return the Saupe matrix element S(x,y) */
	public double getSxy( )
	{
		return m_Sxy;
	}
	
	/** @return the Saupe matrix element S(x,z) */
	public double getSxz( )
	{
		return m_Sxz;
	}
	
	/** @return the Saupe matrix element S(y,y) */
	public double getSyy( )
	{
		return m_Syy;
	}
	
	/** @return the Saupe matrix element S(y,z) */
	public double getSyz( )
	{
		return m_Syz;
	}
	
	/** @return the Saupe matrix element S(z,z) */
	public double getSzz( )
	{
		return m_Szz;
	}
	
	/** @return the eigenvalue assigned to the X axis */
	public double getDxx( )
	{
		return getEigenvalue( AlignmentTensorAxis.X );
	}
	
	/** @return the eigenvalue assigned to the Y axis */
	public double getDyy( )
	{
		return getEigenvalue( AlignmentTensorAxis.Y );
	}
	
	/** @return the eigenvalue assigned to the Z axis */
	public double getDzz( )
	{
		return getEigenvalue( AlignmentTensorAxis.Z );
	}
	
	/**
	 * @param axis the principal axis of interest
	 * @return the eigenvalue assigned to that axis
	 */
	public double getEigenvalue( AlignmentTensorAxis axis )
	{
		return m_eigs[axis.ordinal()].getEigenvalue();
	}
	
	/** @return the eigenvector assigned to the X axis */
	public Vector3 getXAxis( )
	{
		return getAxis( AlignmentTensorAxis.X );
	}
	
	/** @return the eigenvector assigned to the Y axis */
	public Vector3 getYAxis( )
	{
		return getAxis( AlignmentTensorAxis.Y );
	}
	
	/** @return the eigenvector assigned to the Z axis */
	public Vector3 getZAxis( )
	{
		return getAxis( AlignmentTensorAxis.Z );
	}
	
	/**
	 * @param axis the principal axis of interest
	 * @return the eigenvector assigned to that axis, as returned by EigPair
	 */
	public Vector3 getAxis( AlignmentTensorAxis axis )
	{
		return m_eigs[axis.ordinal()].getEigenvector();
	}
	
	/**
	 * Computes the asymmetry (Dxx - Dyy) / Dzz.
	 * The misspelled name is kept for backward compatibility.
	 * 
	 * @return the asymmetry; infinite or NaN if Dzz is zero
	 */
	public double getAssymmetry( )
	{
		return ( getDxx() - getDyy() ) / getDzz();
	}
	
	/**
	 * @return the rhombicity, defined here as 2/3 of the asymmetry
	 */
	public double getRhombicity( )
	{
		return getAssymmetry() * 2.0 / 3.0;
	}
	
	
	/**************************
	 *   Static Methods
	 **************************/
	
	/**
	 * Fits an alignment tensor to the RDCs using a copy of the given subunit.
	 * 
	 * @see #compute(Protein, List)
	 */
	public static AlignmentTensor compute( Subunit subunit, List<Rdc<AtomAddressInternal>> rdcs )
	{
		return compute( new Protein( new Subunit( subunit ) ), rdcs );
	}
	
	/**
	 * Fits an alignment tensor to the RDCs by linear least squares.
	 * 
	 * Each RDC contributes one row to A and one entry to b in A x = b, where
	 * x = (Sxy, Sxz, Syy, Syz, Szz). The row coefficients follow from
	 * expanding v^T S v for the unit internuclear vector v and substituting
	 * Sxx = -Syy - Szz (the tensor is traceless).
	 * 
	 * @param protein structure supplying the atom positions
	 * @param rdcs the measured RDCs; at least 5 are required
	 * @return the best-fit tensor
	 * @throws IllegalArgumentException if fewer than 5 RDCs are given
	 * @throws RuntimeException from Jama if the system is rank deficient or singular
	 */
	public static AlignmentTensor compute( Protein protein, List<Rdc<AtomAddressInternal>> rdcs )
	{
		// five unknowns need at least five equations; fail early with a clear message
		// instead of a generic "rank deficient" error from deep inside Jama
		if( rdcs.size() < NumTensorElements )
		{
			throw new IllegalArgumentException(
				"At least " + NumTensorElements + " RDCs are required to compute an alignment tensor, but only "
				+ rdcs.size() + " were given."
			);
		}
		
		// build the matrix of vector products and the vector of rdc values in one pass
		Matrix A = new Matrix( rdcs.size(), NumTensorElements );
		Matrix b = new Matrix( rdcs.size(), 1 );
		int row = 0;
		for( Rdc<AtomAddressInternal> rdc : rdcs )
		{
			Vector3 vec = getInternuclearUnitVector( protein, rdc );
			
			A.set( row, 0, 2.0 * vec.x * vec.y );
			A.set( row, 1, 2.0 * vec.x * vec.z );
			A.set( row, 2, vec.y * vec.y - vec.x * vec.x );
			A.set( row, 3, 2.0 * vec.y * vec.z );
			A.set( row, 4, vec.z * vec.z - vec.x * vec.x );
			b.set( row, 0, rdc.getValue() );
			
			row++;
		}
		
		// solve Ax = b for x: Jama uses LU for square A and QR (least squares) otherwise.
		// If ill-conditioned systems become a problem, an explicit SVD pseudo-inverse
		// (Jama's SingularValueDecomposition) is an alternative.
		Matrix x = A.solve( b );
		
		// build the tensor
		assert( x.getRowDimension() == NumTensorElements && x.getColumnDimension() == 1 );
		return new AlignmentTensor(
			x.get( 0, 0 ),
			x.get( 1, 0 ),
			x.get( 2, 0 ),
			x.get( 3, 0 ),
			x.get( 4, 0 )
		);
	}
	
	/**
	 * Computes tensors from resampled RDCs using a copy of the given subunit,
	 * reporting to the class logger.
	 * 
	 * @see #compute(Protein, List, int, MessageListener)
	 */
	public static List<AlignmentTensor> compute( Subunit subunit, List<Rdc<AtomAddressInternal>> rdcs, int numSamples )
	{
		return compute( new Protein( new Subunit( subunit ) ), rdcs, numSamples );
	}
	
	/**
	 * Computes tensors from resampled RDCs, reporting to the class logger at INFO level.
	 * 
	 * @see #compute(Protein, List, int, MessageListener)
	 */
	public static List<AlignmentTensor> compute( Protein protein, List<Rdc<AtomAddressInternal>> rdcs, int numSamples )
	{
		return compute( protein, rdcs, numSamples, new LoggingMessageListener( m_log, Level.INFO ) );
	}
	
	/**
	 * Computes one tensor per sample, each fitted to a set of RDC values
	 * produced by Rdc.sample from the input RDCs. The input list is not
	 * modified; samples are written into a deep copy.
	 * 
	 * @param protein structure supplying the atom positions
	 * @param rdcs the original RDCs to sample from
	 * @param numSamples number of tensors to compute
	 * @param listener receives progress and summary messages; may be null for silent operation
	 * @return the computed tensors, in sampling order
	 */
	public static List<AlignmentTensor> compute( Protein protein, List<Rdc<AtomAddressInternal>> rdcs, int numSamples, MessageListener listener )
	{
		// make a copy of the RDCs so we can modify them
		List<Rdc<AtomAddressInternal>> sampledRdcs = Rdc.copyDeep( rdcs );
		
		// start a progress bar if needed
		Progress progress = null;
		if( listener != null )
		{
			listener.message( "Sampling " + numSamples + " sets of RDCs..." );
			progress = new Progress( numSamples, 5000 );
			progress.setMessageListener( listener );
		}
		
		// for each sample...
		ArrayList<AlignmentTensor> tensors = new ArrayList<AlignmentTensor>();
		double maxDeviation = 0.0;
		for( int i=0; i<numSamples; i++ )
		{
			// overwrite the sampled copy, then fit a tensor to it; the tensor keeps
			// only the solved doubles, so reusing sampledRdcs across iterations is safe
			Rdc.sample( sampledRdcs, rdcs );
			tensors.add( compute( protein, sampledRdcs ) );
			
			// update max rdc value deviation
			for( int j=0; j<rdcs.size(); j++ )
			{
				maxDeviation = Math.max( maxDeviation, Math.abs( rdcs.get( j ).getValue() - sampledRdcs.get( j ).getValue() ) );
			}
			
			// update progress if needed
			if( progress != null )
			{
				progress.incrementProgress();
			}
		}
		
		// report the max RDC value deviation (the listener is optional)
		if( listener != null )
		{
			listener.message( "Max RDC value deviation is " + maxDeviation );
		}
		
		return tensors;
	}
	
	/**
	 * Returns the normalized vector pointing from the RDC's "from" atom to its "to" atom.
	 * 
	 * @param protein structure supplying the atom positions
	 * @param rdc the RDC whose atoms are looked up in the protein
	 * @return a new unit vector; the atom positions are not modified
	 */
	private static Vector3 getInternuclearUnitVector( Protein protein, Rdc<AtomAddressInternal> rdc )
	{
		Atom fromAtom = protein.getAtom( rdc.getFrom() );
		Atom toAtom = protein.getAtom( rdc.getTo() );
		
		Vector3 vec = new Vector3( toAtom.getPosition() );
		vec.subtract( fromAtom.getPosition() );
		vec.normalize();
		return vec;
	}
	
	
	/**************************
	 *   Methods
	 **************************/
	
	/**
	 * Builds a statistics report using a copy of the given subunit.
	 * 
	 * @see #getStats(Protein, List)
	 */
	public String getStats( Subunit subunit, List<Rdc<AtomAddressInternal>> rdcs )
	{
		return getStats( new Protein( new Subunit( subunit ) ), rdcs );
	}
	
	/**
	 * Builds a multi-line, human-readable report with the eigenvalues,
	 * asymmetry, rhombicity and RDC RMSD of this tensor.
	 * 
	 * @param protein structure supplying the atom positions
	 * @param rdcs the RDCs used for the RMSD
	 * @return the report text
	 */
	public String getStats( Protein protein, List<Rdc<AtomAddressInternal>> rdcs )
	{
		StringBuilder buf = new StringBuilder();
		buf.append( "Computed Alignment Tensor:" );
		buf.append( "\n\tEigenvalues:" );
		buf.append( String.format( "\n\t\tDxx: %f", getDxx() ) );
		buf.append( String.format( "\n\t\tDyy: %f", getDyy() ) );
		buf.append( String.format( "\n\t\tDzz: %f", getDzz() ) );
		buf.append( "\n\tProperties:" );
		buf.append( String.format( "\n\t\tAssymmetry: %f", getAssymmetry() ) );
		buf.append( String.format( "\n\t\tRhombicity: %f", getRhombicity() ) );
		buf.append( String.format( "\n\t\tRDC RMSD: %f", getRmsd( protein, rdcs ) ) );
		return buf.toString();
	}
	
	/**
	 * @return a new symmetric Saupe matrix, with Sxx = -Syy - Szz so the trace is zero
	 */
	public Matrix3 getSaupe( )
	{
		return new Matrix3(
			-m_Syy-m_Szz, m_Sxy, m_Sxz,
			m_Sxy, m_Syy, m_Syz,
			m_Sxz, m_Syz, m_Szz
		);
	}
	
	/**
	 * Computes the RDC RMSD using a copy of the given subunit.
	 * 
	 * @see #getRmsd(Protein, List)
	 */
	public double getRmsd( Subunit subunit, List<Rdc<AtomAddressInternal>> rdcs )
	{
		return getRmsd( new Protein( new Subunit( subunit ) ), rdcs );
	}
	
	/**
	 * Computes the root-mean-square deviation between the back-computed and measured RDC values.
	 * 
	 * @param protein structure supplying the atom positions
	 * @param rdcs the measured RDCs
	 * @return the RMSD; NaN when rdcs is empty
	 */
	public double getRmsd( Protein protein, List<Rdc<AtomAddressInternal>> rdcs )
	{
		double sum = 0.0;
		ArrayList<Double> values = backComputeRdcs( protein, rdcs );
		for( int i=0; i<values.size(); i++ )
		{
			// squaring makes an absolute value unnecessary
			double diff = values.get( i ) - rdcs.get( i ).getValue();
			sum += diff * diff;
		}
		return Math.sqrt( sum / values.size() );
	}
	
	/**
	 * Back-computes RDC values using a copy of the given subunit.
	 * 
	 * @see #backComputeRdcs(Protein, List)
	 */
	public ArrayList<Double> backComputeRdcs( Subunit subunit, List<Rdc<AtomAddressInternal>> rdcs )
	{
		return backComputeRdcs( new Protein( new Subunit( subunit ) ), rdcs );
	}
	
	/**
	 * Predicts an RDC value for each input RDC from this tensor and the
	 * unit internuclear vector v, i.e. v^T S v with Sxx = -Syy - Szz.
	 * These are the same coefficients used to build the fit in compute().
	 * 
	 * @param protein structure supplying the atom positions
	 * @param rdcs the RDCs whose atom pairs define the vectors; their values are ignored
	 * @return the predicted values, in the same order as rdcs
	 */
	public ArrayList<Double> backComputeRdcs( Protein protein, List<Rdc<AtomAddressInternal>> rdcs )
	{
		ArrayList<Double> values = new ArrayList<Double>( rdcs.size() );
		for( Rdc<AtomAddressInternal> rdc : rdcs )
		{
			Vector3 vec = getInternuclearUnitVector( protein, rdc );
			
			double r =
				m_Sxy * 2.0 * vec.x * vec.y
				+ m_Sxz * 2.0 * vec.x * vec.z
				+ m_Syy * ( vec.y * vec.y - vec.x * vec.x )
				+ m_Syz * 2.0 * vec.y * vec.z
				+ m_Szz * ( vec.z * vec.z - vec.x * vec.x );
			values.add( r );
		}
		return values;
	}
}
