/*******************************************************************************
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 * 
 * Contact Info:
 * 	Bruce Donald
 * 	Duke University
 * 	Department of Computer Science
 * 	Levine Science Research Center (LSRC)
 * 	Durham
 * 	NC 27708-0129 
 * 	USA
 * 	brd@cs.duke.edu
 * 
 * Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
 * 
 * <signature of Bruce Donald>, April 2011
 * Bruce Donald, Professor of Computer Science
 ******************************************************************************/


package edu.duke.donaldLab.jdshot.disco;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import org.apache.log4j.Logger;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;

import edu.duke.donaldLab.share.geom.Vector3;
import edu.duke.donaldLab.share.io.Logging;
import edu.duke.donaldLab.share.math.CompareReal;
import edu.duke.donaldLab.share.math.Matrix3;
import edu.duke.donaldLab.share.math.PointIteratorDelta;
import edu.duke.donaldLab.share.math.Quaternion;
import edu.duke.donaldLab.share.nmr.AlignmentTensor;
import edu.duke.donaldLab.share.nmr.AlignmentTensorAxis;
import edu.duke.donaldLab.share.protein.Subunit;
import edu.duke.donaldLab.share.protein.tools.ProteinGeometry;

public class OrientationCalculator
{
	/**************************
	 *   Definitions
	 **************************/
	
	/** Numerical tolerance used for near-equality and near-degeneracy checks in this class. */
	private static final double Epsilon = 1e-2;
	
	
	/**************************
	 *   Data Members
	 **************************/
	
	private static final Logger m_log = Logging.getLog( OrientationCalculator.class );
	
	
	/**************************
	 *   Static Methods
	 **************************/
	
	/**
	 * Computes candidate symmetry-axis orientations from an alignment tensor.
	 * 
	 * @param numSubunits number of subunits in the oligomer; must be at least 2
	 * @param tensor the alignment tensor to derive the orientations from
	 * @return the result of {@link #computeOrientationsDimer} when numSubunits is 2,
	 *         otherwise the result of {@link #computeOrientationsHigherOligomer}
	 * @throws IllegalArgumentException if numSubunits is less than 2
	 */
	public static ArrayList<Vector3> computeOrientations( int numSubunits, AlignmentTensor tensor )
	{
		if( numSubunits == 2 )
		{
			return computeOrientationsDimer( tensor );
		}
		else if( numSubunits > 2 )
		{
			return computeOrientationsHigherOligomer( tensor );
		}
		else
		{
			throw new IllegalArgumentException( "numSubunits must be 2 or greater, but instead is " + numSubunits + "." );
		}
	}
	
	/**
	 * Computes the candidate symmetry axes for a dimer: all three principal axes of the alignment tensor.
	 * All principal values are expected to be distinct; a warning is logged (but the axes are still
	 * returned) for every pair of Dxx, Dyy, Dzz that compares equal within {@code Epsilon}.
	 * 
	 * @param tensor the alignment tensor
	 * @return a new list of three vectors: copies of the tensor's x, y and z axes, each negated
	 */
	public static ArrayList<Vector3> computeOrientationsDimer( AlignmentTensor tensor )
	{
		// all eigenvalues must be distinct
		if( CompareReal.eq( tensor.getDxx(), tensor.getDyy(), Epsilon ) )
		{
			m_log.warn( String.format( "Alignment Tensor Dxx (%f) and Dyy (%f) are too similar: ", tensor.getDxx(), tensor.getDyy() ) );
		}
		if( CompareReal.eq( tensor.getDxx(), tensor.getDzz(), Epsilon ) )
		{
			m_log.warn( String.format( "Alignment Tensor Dxx (%f) and Dzz (%f) are too similar: ", tensor.getDxx(), tensor.getDzz() ) );
		}
		if( CompareReal.eq( tensor.getDyy(), tensor.getDzz(), Epsilon ) )
		{
			m_log.warn( String.format( "Alignment Tensor Dyy (%f) and Dzz (%f) are too similar: ", tensor.getDyy(), tensor.getDzz() ) );
		}
		
		// simply return all the eigenvectors
		// NOTE: we don't need to check both flips since our rotation is ultimately 180 deg
		ArrayList<Vector3> vectors = new ArrayList<Vector3>( 3 );
		vectors.add( new Vector3( tensor.getXAxis() ) );
		vectors.add( new Vector3( tensor.getYAxis() ) );
		vectors.add( new Vector3( tensor.getZAxis() ) );
		
		// HACKHACK: flip all three vectors so they use the same conventions as previous software
		for( Vector3 vector : vectors )
		{
			vector.negate();
		}
		
		return vectors;
	}
	
	/**
	 * Computes the candidate symmetry axis for an oligomer with more than two subunits: the tensor's z-axis.
	 * Logs a warning (but still returns the axis) when Dxx and Dyy are not equal within a tolerance of 1.0.
	 * 
	 * @param tensor the alignment tensor
	 * @return a new single-element list holding a copy of the tensor's z-axis
	 */
	public static ArrayList<Vector3> computeOrientationsHigherOligomer( AlignmentTensor tensor )
	{
		// make sure Dxx = Dyy (approximately)
		if( CompareReal.neq( tensor.getDxx(), tensor.getDyy(), 1.0 ) )
		{
			m_log.warn( String.format( "Alignment Tensor Dxx (%f) and Dyy (%f) are not very similar: ", tensor.getDxx(), tensor.getDyy() ) );
		}
		
		// return the z axis of the tensor
		ArrayList<Vector3> vectors = new ArrayList<Vector3>( 1 );
		vectors.add( new Vector3( tensor.getZAxis() ) );
		return vectors;
	}
	
	/**
	 * Selects the tensors whose given axis lies within 90 degrees of the same axis of {@code bestTensor}.
	 * 
	 * @see #pickTensorCluster(List, Vector3, AlignmentTensorAxis)
	 */
	public static ArrayList<AlignmentTensor> pickTensorCluster( List<AlignmentTensor> tensors, AlignmentTensor bestTensor, AlignmentTensorAxis axis )
	{
		return pickTensorCluster( tensors, bestTensor.getAxis( axis ), axis );
	}
	
	/**
	 * Selects the tensors whose given axis lies within 90 degrees (pi/2, inclusive) of {@code bestAxis}.
	 * Angles are measured as acos of the dot product, which assumes unit-length axes.
	 * 
	 * @param tensors the candidate tensors
	 * @param bestAxis the reference direction
	 * @param axis which tensor axis to compare against the reference direction
	 * @return a new list of the selected tensors, in their original order
	 */
	public static ArrayList<AlignmentTensor> pickTensorCluster( List<AlignmentTensor> tensors, Vector3 bestAxis, AlignmentTensorAxis axis )
	{
		ArrayList<AlignmentTensor> pickedTensors = new ArrayList<AlignmentTensor>();
		for( AlignmentTensor tensor : tensors )
		{
			// if axes are within 90 deg (i.e. in the same hemisphere), they're part of our cluster
			double angle = getAngleRadians( bestAxis, tensor.getAxis( axis ) );
			if( CompareReal.lte( angle, Math.PI / 2.0 ) )
			{
				pickedTensors.add( tensor );
			}
		}
		return pickedTensors;
	}
	
	/**
	 * Returns the largest angle, in radians, between the given axis of {@code bestTensor} and the same
	 * axis of each tensor in {@code tensors}. Returns 0.0 for an empty list.
	 */
	public static double getMaxAngleDeviationRadians( List<AlignmentTensor> tensors, AlignmentTensor bestTensor, AlignmentTensorAxis axis )
	{
		Vector3 bestAxis = bestTensor.getAxis( axis );
		double maxAngleRadians = 0.0;
		for( AlignmentTensor tensor : tensors )
		{
			maxAngleRadians = Math.max( maxAngleRadians, getAngleRadians( bestAxis, tensor.getAxis( axis ) ) );
		}
		return maxAngleRadians;
	}
	
	/**
	 * Returns the half-angle, in radians, of a circular cone around the given axis of {@code bestTensor}
	 * that contains the requested fraction of the tensors' axes.
	 * 
	 * @param percentile fraction in [0,1]; the result is the sorted deviation at index floor( (n-1) * percentile )
	 * @throws IllegalArgumentException if tensors is empty or percentile is outside [0,1]
	 */
	public static double getCircularConeByPercentile( List<AlignmentTensor> tensors, AlignmentTensor bestTensor, AlignmentTensorAxis axis, double percentile )
	{
		// convert to a list of deviation angles
		Vector3 bestAxis = bestTensor.getAxis( axis );
		double[] deviations = new double[tensors.size()];
		for( int i=0; i<tensors.size(); i++ )
		{
			deviations[i] = getAngleRadians( bestAxis, tensors.get( i ).getAxis( axis ) );
		}
		Arrays.sort( deviations );
		
		return getSortedPercentile( deviations, percentile );
	}
	
	/**
	 * Fits an elliptical cone around the given axis of {@code bestTensor} to the tensors' axes.
	 * The ellipse directions are the two principal components (largest eigenvalue magnitudes) of the
	 * sampled axis directions, projected onto the plane normal to the central axis and normalized.
	 * The angular extent along each direction is taken at the requested percentile.
	 * 
	 * @param percentile fraction in [0,1] used for both angular extents
	 * @throws IllegalArgumentException if tensors is empty or percentile is outside [0,1]
	 */
	public static EllipticalCone getEllipticalConeByPercentile( List<AlignmentTensor> tensors, AlignmentTensor bestTensor, AlignmentTensorAxis axis, double percentile )
	{
		if( tensors.isEmpty() )
		{
			throw new IllegalArgumentException( "Cannot compute an elliptical cone from an empty list of tensors." );
		}
		
		// find the two major axis directions for the ellipse using PCA
		Vector3 centroid = new Vector3( 0.0, 0.0, 0.0 );
		for( AlignmentTensor tensor : tensors )
		{
			centroid.add( tensor.getAxis( axis ) );
		}
		centroid.scale( 1.0 / (double)tensors.size() );
		Matrix m = new Matrix( 3, tensors.size() );
		for( int i=0; i<tensors.size(); i++ )
		{
			Vector3 pos = tensors.get( i ).getAxis( axis );
			m.set( 0, i, pos.x - centroid.x );
			m.set( 1, i, pos.y - centroid.y );
			m.set( 2, i, pos.z - centroid.z );
		}
		
		// m * m^T is the (unnormalized) 3x3 covariance matrix of the centered samples
		Matrix cov = m.times( m.transpose() );
		EigenvalueDecomposition eig = new EigenvalueDecomposition( cov );
		Matrix eigenvectors = eig.getV();
		
		// the covariance matrix is symmetric, so its eigenvalues are real and equal the diagonal of getD();
		// fetch them once rather than calling getD() inside every comparison
		final double[] eigenvalues = eig.getRealEigenvalues();
		
		// sort eigenvector indices by decreasing eigenvalue magnitude
		List<Integer> indices = Arrays.asList( 0, 1, 2 );
		Collections.sort( indices, new Comparator<Integer>( )
		{
			@Override
			public int compare( Integer a, Integer b )
			{
				// the index is first if its eigenvalue has a larger magnitude
				return Double.compare( Math.abs( eigenvalues[b] ), Math.abs( eigenvalues[a] ) );
			}
		} );
		
		// the two eigenvectors with the largest eigenvalue magnitudes are the elliptical axes
		// (but project them onto the plane normal to the central axis)
		EllipticalCone cone = new EllipticalCone();
		cone.centralAxis = bestTensor.getAxis( axis );
		cone.xAxis = getProjectedUnitColumn( eigenvectors, indices.get( 0 ), cone.centralAxis );
		cone.yAxis = getProjectedUnitColumn( eigenvectors, indices.get( 1 ), cone.centralAxis );
		
		// project all the sampled axes onto the x- and y-axis planes of the cone
		// then measure their angles to the best axis
		cone.xMaxAngleRadians = getAngleRadiansInDirectionByPercentile( tensors, axis, cone.centralAxis, cone.xAxis, percentile );
		cone.yMaxAngleRadians = getAngleRadiansInDirectionByPercentile( tensors, axis, cone.centralAxis, cone.yAxis, percentile );
		
		return cone;
	}
	
	/**
	 * Samples orientations within a circular cone of half-angle {@code maxAngleRadians} around {@code axis}.
	 * The circular cone is expressed as an {@link EllipticalCone} with equal extents along two directions
	 * that are orthogonal to the axis and to each other.
	 * 
	 * @param axis the cone's central axis (assumed to be a unit vector -- TODO confirm with callers);
	 *        it is stored in the cone by reference, not copied
	 * @param maxAngleRadians half-angle of the cone
	 * @param resolutionDegrees sampling resolution
	 */
	public static ArrayList<Vector3> computeOrientationsInRange( Vector3 axis, double maxAngleRadians, double resolutionDegrees )
	{
		// UNDONE: this code is a little sketchy and probably shouldn't be used
		EllipticalCone cone = new EllipticalCone();
		cone.centralAxis = axis;
		
		// the cone's x-axis is the global x direction projected onto the plane normal to the axis;
		// when the axis is (nearly) parallel to global x that projection vanishes, so use global y instead
		// (for a unit axis, the y projection then has length close to 1)
		double minProjectionLengthSq = Epsilon * Epsilon;
		cone.xAxis = new Vector3( 1.0, 0.0, 0.0 );
		cone.xAxis.orthogonalProjection( axis );
		if( cone.xAxis.getDot( cone.xAxis ) < minProjectionLengthSq )
		{
			cone.xAxis = new Vector3( 0.0, 1.0, 0.0 );
			cone.xAxis.orthogonalProjection( axis );
		}
		cone.xAxis.normalize();
		
		// make the y-axis normal to both the central axis and the x-axis
		// (independently projecting global y would generally not be orthogonal to the x-axis)
		cone.yAxis = new Vector3();
		axis.getCross( cone.yAxis, cone.xAxis );
		cone.yAxis.normalize();
		
		cone.xMaxAngleRadians = maxAngleRadians;
		cone.yMaxAngleRadians = maxAngleRadians;
		return computeOrientationsInRange( cone, resolutionDegrees );
	}
	
	/**
	 * Samples orientations inside an elliptical cone.
	 * A 2D angle grid with spacing {@code resolutionDegrees}, starting at (0,0) and bounded by the larger of
	 * the cone's two angular extents (exact range semantics are those of {@link PointIteratorDelta}),
	 * is converted to vectors via {@code Vector3.fromAngles}. Each vector is rotated by the rotation
	 * that takes the unit x-axis onto the cone's central axis, and it is kept only if the cone contains it.
	 * 
	 * @return a new list of the sampled orientations that lie inside the cone
	 */
	public static ArrayList<Vector3> computeOrientationsInRange( EllipticalCone cone, double resolutionDegrees )
	{
		// sample orientations on the unit sphere around the best axis
		/* NOTE:
			Here, we're sampling (theta,phi) space uniformly within an elliptical disc centered at (0,0).
			We also hope the distortion due to singularities isn't too large! =)
		*/
		ArrayList<Vector3> orientations = new ArrayList<Vector3>();
		double maxAngle = Math.max( cone.xMaxAngleRadians, cone.yMaxAngleRadians );
		double resolutionRadians = Math.toRadians( resolutionDegrees );
		PointIteratorDelta iter = new PointIteratorDelta(
			2,
			new double[] { 0.0, 0.0 },
			new double[] { resolutionRadians, resolutionRadians },
			new double[] { maxAngle, maxAngle }
		);
		Quaternion rotation = new Quaternion();
		Quaternion.getRotation( rotation, Vector3.getUnitX(), cone.centralAxis );
		while( iter.hasNext() )
		{
			// build the sampled vector
			Vector3 orientation = new Vector3();
			orientation.fromAngles( iter.next() );
			
			// rotate vector to near the best axis
			rotation.rotate( orientation );
			
			if( cone.contains( orientation ) )
			{
				orientations.add( orientation );
			}
		}
		return orientations;
	}
	
	/**
	 * Rotates the subunit by the rotation that takes {@code computedOrientation} onto the z-axis,
	 * then centers it via {@code ProteinGeometry.center}.
	 */
	public static void normalizeProtein( Subunit subunit, Vector3 computedOrientation )
	{
		// rotate the protein so the symmetry axis is parallel to the z-axis
		Quaternion rotation = new Quaternion();
		Quaternion.getRotation( rotation, computedOrientation, Vector3.getUnitZ() );
		ProteinGeometry.rotate( subunit, rotation );
		
		// center the protein at the origin
		ProteinGeometry.center( subunit );
	}
	
	/**
	 * Translates the subunit by minus its centroid, then rotates it by the rotation that takes
	 * {@code computedOrientation} onto the z-axis. The same translation and rotation are applied to
	 * {@code position}, which is modified in place.
	 */
	public static void normalizeProteinAndAxis( Subunit subunit, Vector3 position, Vector3 computedOrientation )
	{
		// center the protein in the coordinate system
		Vector3 centroid = ProteinGeometry.getCentroid( subunit );
		centroid.negate();
		ProteinGeometry.translate( subunit, centroid );
		
		// rotate the protein so the symmetry axis is parallel to the z-axis
		Quaternion rotation = new Quaternion();
		Quaternion.getRotation( rotation, computedOrientation, Vector3.getUnitZ() );
		ProteinGeometry.rotate( subunit, rotation );
		
		// translate and rotate the axis
		position.add( centroid );
		rotation.rotate( position );
	}
	
	
	/**************************
	 *   Static Functions
	 **************************/
	
	/**
	 * Returns acos of the dot product of two vectors. For unit vectors, this is the angle between them in radians.
	 * The dot product is clamped to [-1,1] because rounding can push it slightly outside that range,
	 * where {@code Math.acos} returns NaN.
	 */
	private static double getAngleRadians( Vector3 a, Vector3 b )
	{
		double dot = a.getDot( b );
		return Math.acos( Math.max( -1.0, Math.min( dot, 1.0 ) ) );
	}
	
	/**
	 * Returns the value of an ascending-sorted array at index floor( (n-1) * percentile ).
	 * 
	 * @throws IllegalArgumentException if the array is empty or percentile is outside [0,1] (or NaN)
	 */
	private static double getSortedPercentile( double[] sortedValues, double percentile )
	{
		if( sortedValues.length == 0 )
		{
			throw new IllegalArgumentException( "Cannot compute a percentile of an empty list of tensors." );
		}
		
		// written as a negated range test so NaN is rejected too
		if( !( percentile >= 0.0 && percentile <= 1.0 ) )
		{
			throw new IllegalArgumentException( "percentile must be in [0,1], but instead is " + percentile + "." );
		}
		return sortedValues[ (int)( (double)( sortedValues.length - 1 ) * percentile ) ];
	}
	
	/**
	 * Builds a vector from column {@code col} of a 3-row matrix, projects it onto the plane normal to
	 * {@code normal}, and normalizes it.
	 */
	private static Vector3 getProjectedUnitColumn( Matrix matrix, int col, Vector3 normal )
	{
		Vector3 v = new Vector3( matrix.get( 0, col ), matrix.get( 1, col ), matrix.get( 2, col ) );
		v.orthogonalProjection( normal );
		v.normalize();
		return v;
	}
	
	/**
	 * Measures, at the given percentile, how far the tensors' axes deviate from {@code centralAxis}
	 * within the plane spanned by {@code centralAxis} and {@code direction}.
	 * Each axis is expressed in the basis (centralAxis, direction, centralAxis x direction). Its third
	 * component is dropped, and it is renormalized and mapped back to the original frame before its angle
	 * to the central axis is measured. Using the transpose as the inverse assumes the basis is orthonormal,
	 * i.e. that centralAxis and direction are orthogonal unit vectors.
	 * 
	 * @throws IllegalArgumentException if tensors is empty or percentile is outside [0,1]
	 */
	private static double getAngleRadiansInDirectionByPercentile( List<AlignmentTensor> tensors, AlignmentTensorAxis axis, Vector3 centralAxis, Vector3 direction, double percentile )
	{
		// construct a right-handed orthonormal basis with x=centralAxis and y=direction
		Vector3 orthogonalDirection = new Vector3();
		centralAxis.getCross( orthogonalDirection, direction );
		Matrix3 basis = new Matrix3();
		basis.setColumns( centralAxis, direction, orthogonalDirection );
		
		// for an orthonormal basis, the inverse is the transpose
		Matrix3 basisInv = new Matrix3( basis );
		basisInv.transpose();
		
		// compute the angles of all the tensor axes
		double[] deviations = new double[tensors.size()];
		Vector3 a = new Vector3();
		for( int i=0; i<tensors.size(); i++ )
		{
			// rotate axis out of the basis
			a.set( tensors.get( i ).getAxis( axis ) );
			basisInv.multiply( a );
			
			// project to xy plane and renormalize
			// NOTE(review): an axis exactly along orthogonalDirection projects to zero here and
			// cannot be renormalized into a direction -- confirm how Vector3.normalize handles that
			a.z = 0.0;
			a.normalize();
			
			// rotate back into the basis
			basis.multiply( a );
			deviations[i] = getAngleRadians( a, centralAxis );
		}
		Arrays.sort( deviations );
		
		return getSortedPercentile( deviations, percentile );
	}
}
