import java.util.*;
import java.io.Serializable;

/**
 * An edge (p,c) of the rooted tree used by the dynamic-programming search. It stores the M, L and lambda
 * vertex sets (molecule-relative residue numbering). For lambda edges it also stores the A matrix: for every
 * state assignment of M, the best state of each vertex in lambda.
 */
public class TreeEdge implements Serializable {

	//Sentinel "infinite" energy used to initialize / reset the running minimum in computeAhelper(.)
	private static final float INITIAL_BEST_ENERGY = (float)Math.pow(10,38);

	private TreeNode p = null; //the parent node adjacent to the edge in the rooted tree (not read in this class)
	private TreeNode c = null; //the child node adjacent to the edge in the rooted tree
	
	private boolean isLambdaEdge = false; //determines if this is a lambda edge (non-empty lambda set)
	
	private Set<Integer> M = null; //the M set (vertices have molecule index-relative numbering)
	private Set<Integer> L = null; //the L set (vertices have molecule index-relative numbering)
	private Set<Integer> lambda = null; //the lambda set (vertices have molecule index-relative numbering)
	
	private int molResMap[] = null; //maps redesign-order indices to molecule-relative residue numbers
	private int invResMap[] = null; //maps the indices of molecule-relative residue numbers to redesign-order indexing (used in the energy matrices); ligand is mapped last
	
	private int A[][] = null; //the A matrix for the current tree edge
	private float energy[] = null; //the computed (best) energy for each state assignment of M, i.e. each row of A
	
	private int sysStrNum = -1; //the strand number in the molecule of the system strand
	
	private int numStates[] = null; //the number of allowed states (unpruned rotamers) for each redesign position
	
	//the mapping between state indices and residue/aa/rot tuples; first index is the depth (M vertices first, then lambda)
	private RotTypeMap rtm[][] = null;
	
	int numInAS = -1; //total number of redesign residues in the system strand only (ligand if present is in addition to this)
	
	/**
	 * Creates a tree edge. Computes its lambda set and, for a lambda edge, allocates the A/energy matrices
	 * and the state-to-tuple maps.
	 *
	 * NOTE(review): the fields c, M and L are never assigned anywhere in this class, so computeLambda()
	 * dereferences a null child node here. Presumably the tree nodes and the M/L sets are meant to be
	 * supplied before lambda is computed -- TODO confirm against the callers.
	 *
	 * @param numUnprunedRot number of allowed states for each redesign position (redesign-order indexing)
	 * @param molResidueMap maps redesign-order indices to molecule-relative residue numbers
	 * @param invResidueMap maps molecule-relative residue numbers to redesign-order indices
	 * @param sysStrandNum strand number of the system strand in the molecule
	 * @param numResInAS number of redesign residues in the system strand only
	 */
	public TreeEdge(int numUnprunedRot[], int molResidueMap[], int invResidueMap[], int sysStrandNum, int numResInAS){
		
		numStates = numUnprunedRot;
		molResMap = molResidueMap;
		invResMap = invResidueMap;
		sysStrNum = sysStrandNum;
		numInAS = numResInAS;
		
		/*
		 * M, L, and lambda should use molecule-relative residue numbering, and not the input pdb-relative numbering
		 */
		
		computeLambda();
		
		if (isLambdaEdge){ //this is a lambda edge
			
			initializeMatrices();
			
			rtm = new RotTypeMap[M.size()+lambda.size()][];
		}
	}
	
	//Computes the lambda set for the current tree edge:
	//	lambda = L minus the union of the L sets of the edges to the two children (lambda = L for a leaf node)
	private void computeLambda(){
		
		TreeNode clc = c.getlc();
		lambda = new LinkedHashSet<Integer>(L);
		
		if (clc!=null) { //internal tree node, so both children exist
			lambda.removeAll(clc.getCofEdge().getL());
			lambda.removeAll(c.getrc().getCofEdge().getL());
		}
		
		isLambdaEdge = !lambda.isEmpty();
	}
	
	//Initialize the A[] and energy[] matrices;
	//		Each entry in the first dimension of A corresponds to a unique state assignment in M; for each such entry,
	//			the second dimension gives the best state for each graph vertex in lambda
	private void initializeMatrices(){
		
		int size = 1;
		for (int i=0; i<numStates.length; i++){
			if (M.contains(molResMap[i])){
				size *= numStates[i];
			}
		}
		
		A = new int[size][lambda.size()];
		energy = new float[size];
	}
	
	public Set<Integer> getL(){
		return L;
	}
	
	public Set<Integer> getM(){
		return M;
	}
	
	public boolean getIsLambdaEdge(){
		return isLambdaEdge;
	}
	
	public Set<Integer> getLambda(){
		return lambda;
	}
	
	public int [][] getA(){
		return A;
	}
	
	public TreeNode getc(){
		return c;
	}
	
	/**
	 * Computes and stores the A matrix (and energy[]) for the current edge by enumerating every unpruned state
	 * assignment of (M u lambda). The states of (L-lambda) are backtracked from the A matrices of the lambda
	 * edges below this one. Does nothing for a non-lambda edge, since A, energy and the state maps are only
	 * allocated for lambda edges.
	 */
	public void computeA(StrandRotamers sysLR, StrandRotamers ligRot, Molecule m, RotamerLibrary rl, RotamerLibrary grl, 
			boolean prunedRot[], int numTotalRot, int rotIndOffset[], float eMatrix[][][][][][], InteractionGraph G){
		
		if (!isLambdaEdge) //no A / rtm allocated for this edge
			return;
		
		int maxDepth = M.size() + lambda.size();
		
		Object arrayM[] = M.toArray();
		Object arrayLambda[] = lambda.toArray();
		
		int curState[] = new int[maxDepth];
		int bestState[] = new int[maxDepth];
		Arrays.fill(curState, -1);
		Arrays.fill(bestState, -1);
		
		float bestEnergy[] = new float[]{INITIAL_BEST_ENERGY};
		
		computeAhelper(0, maxDepth, arrayM, arrayLambda, sysLR, ligRot, m, rl, grl, prunedRot, numTotalRot, rotIndOffset,
				curState, eMatrix, G, bestState, bestEnergy);
	}
	
	//Called by computeA(.); recursively assigns a state to the vertex at position 'depth' of (M u lambda)
	//	(M vertices first, then lambda vertices). At full depth, the (L-lambda) states are backtracked and the energy is evaluated.
	private void computeAhelper(int depth, int maxDepth, Object arrayM[], Object arrayLambda[], StrandRotamers sysLR, StrandRotamers ligRot,
			Molecule m, RotamerLibrary rl, RotamerLibrary grl, boolean prunedRot[], int numTotalRot, int rotIndOffset[], 
			int curState[], float eMatrix[][][][][][], InteractionGraph G, int bestState[], float bestEnergy[]){
		
		if (depth >= maxDepth){ //end level of recursive calls; call the backtracking procedure to look-up the optimal states for (L-lambda)
			
			Set<Integer> ll = new LinkedHashSet<Integer>(L); //the set difference (L-lambda) for this edge
			ll.removeAll(lambda);
			Object arrayLL[] = ll.toArray();
			int curStateLL[] = new int[arrayLL.length];
			TreeEdge curStateLLedges[] = new TreeEdge[arrayLL.length];
			
			bTrack(curStateLLedges, curStateLL, arrayLL, curState);
			
			float en = computeEforState(curState,curStateLL,curStateLLedges,arrayLL,eMatrix,m,G);
			
			if (en<bestEnergy[0]) { //new best energy, so update to the current state assignment
				bestEnergy[0] = en;
				System.arraycopy(curState, 0, bestState, 0, curState.length);
			}
			return;
		}
		
		int numM = arrayM.length;
		
		//graph vertex handled at this depth: M vertices occupy depths [0,numM), lambda vertices the depths after them
		int v = (depth < numM) ? (Integer)arrayM[depth] : (Integer)arrayLambda[depth - numM];
		
		int curPos = invResMap[v];
		
		//the state-index -> tuple mapping at a given depth is the same on every visit, so allocate it only once
		if (rtm[depth]==null)
			rtm[depth] = new RotTypeMap[numStates[curPos]];
		
		int curStrandResNum = m.residue[v].strandResidueNumber;
		
		boolean inSysStrand = (m.residue[v].strandNumber==sysStrNum); //otherwise the residue is in the ligand strand
		StrandRotamers str = inSysStrand ? sysLR : ligRot;
		RotamerLibrary rotLib = inSysStrand ? rl : grl;
		
		for(int q=0;q<str.getNumAllowable(curStrandResNum);q++) { //for all allowed amino acid types
			
			int AAindex = str.getIndexOfNthAllowable(curStrandResNum,q);
			
			int numRot = rotLib.getNumRotForAAtype(AAindex);
			if (numRot==0) //an amino acid type without rotamers still contributes a single state
				numRot = 1;
			
			for(int w=0;w<numRot;w++) { //for all rotamers
				
				if (!prunedRot[curPos*numTotalRot + rotIndOffset[AAindex] + w]){ //rotamer not pruned, so check
					
					curState[depth]++;
					
					if (rtm[depth][curState[depth]]==null) //store the mapping from state index to residue/aa/rot tuple
						rtm[depth][curState[depth]] = new RotTypeMap(curPos,AAindex,w);
					
					computeAhelper(depth+1, maxDepth, arrayM, arrayLambda, sysLR, ligRot, m, rl, grl, prunedRot, numTotalRot, rotIndOffset,
							curState, eMatrix, G, bestState, bestEnergy);
				}
			}
		}
		
		curState[depth] = -1;
		
		if (depth==numM){ //done with the lambda states for the current state assignment in M, so update A[] and energy[]
			
			storeBestStateLambda(bestState, arrayM, bestEnergy[0]); //store the best state for each vertex in lambda, for the current state assignment in M
			
			bestEnergy[0] = INITIAL_BEST_ENERGY;
		}
	}
	
	//Backtracking procedure for the matrix computation (for a given state assignment in (M u lambda), looks up the optimal states for (L-lambda) )
	private void bTrack(TreeEdge curStateLLedges[], int curStateLL[], Object topArrayLL[], int curState[]){
		
		bTrackHelper(this, curStateLLedges, curStateLL, topArrayLL, curState);
	}
	
	//Called by bTrack(.); looks up the states of the (L-lambda) vertices of edge e from the nearest lambda edges (Fi) below e,
	//	recursing into those edges for vertices that are not in their lambda sets
	private void bTrackHelper(TreeEdge e, TreeEdge curStateLLedges[], int curStateLL[], Object topArrayLL[], int curState[]){
		
		Set<Integer> ll = new LinkedHashSet<Integer>(e.L); //the set difference (L-lambda) for the current edge
		ll.removeAll(e.lambda);
		Object arrayLL[] = ll.toArray();
		
		//Fi depends only on e (not on the graph vertex), so compute it once
		Set<TreeEdge> Fi = new LinkedHashSet<TreeEdge>();
		bTrackHelperFi(e.getc(),Fi);
		Object arrayFi[] = Fi.toArray();
		
		boolean lookedUp[] = new boolean[arrayLL.length];
		
		for (int i=0; i<arrayLL.length; i++){ //all graph vertices in (L-lambda)
			
			for (Object o : arrayFi){ //all tree edges in Fi
				
				TreeEdge fk = (TreeEdge)o;
				
				if ( (fk!=null) && (fk.getLambda().contains(arrayLL[i])) ){ //perform computation only if arrayLL[i] is in the lambda set of edge fk
					
					lookedUp[i] = true;
					int k = indexOfVertex(topArrayLL, arrayLL[i]);
					if (k>=0){
						curStateLL[k] = getStateForV(curState,(Integer)arrayLL[i],fk,topArrayLL,curStateLL,curStateLLedges);
						curStateLLedges[k] = fk;
					}
				}
			}
		}
		
		//This second iteration is necessary (and not part of the first iteration) to ensure that all graph vertices in lambda(fk) for any fk in Fi have been looked up
		for (int i=0; i<arrayLL.length; i++){ //all graph vertices in (L-lambda)
			
			if (lookedUp[i]) //optimal state for the current graph vertex already found
				continue;
			
			for (Object o : arrayFi){ //all tree edges in Fi
				
				TreeEdge fk = (TreeEdge)o;
				
				if ( (fk!=null) && (fk.getL().contains(arrayLL[i])) && (!fk.getLambda().contains(arrayLL[i])) ){ //arrayLL[i] is in the (L-lambda) set of fk
					
					bTrackHelper(fk, curStateLLedges, curStateLL, topArrayLL, curState); //recursively call the backtracking procedure
				}
			}
		}
	}
	
	//Called by bTrackHelper(.); finds the tree edges belonging to the set Fi starting at the sub-tree rooted at the tree node tn
	//	(the first lambda edges encountered on each downward path; their subtrees are not traversed)
	private void bTrackHelperFi(TreeNode tn, Set<TreeEdge> Fi){
		
		TreeNode clc = tn.getlc();
		TreeNode crc = tn.getrc();
		
		if (clc!=null){
			TreeEdge clce = clc.getCofEdge();
			if (clce.getIsLambdaEdge()) //if clce is a lambda edge, add it to Fi, and do not traverse the subtree rooted at clc
				Fi.add(clce);
			else //not a lambda edge, so traverse the subtree rooted at clc
				bTrackHelperFi(clc,Fi);
		}
		
		if (crc!=null){
			TreeEdge crce = crc.getCofEdge();
			if (crce.getIsLambdaEdge()) 
				Fi.add(crce);
			else
				bTrackHelperFi(crc,Fi);
		}
	}
	
	//Returns the position of graph vertex v in arr (compared by value, not by reference), or -1 if absent
	private static int indexOfVertex(Object arr[], Object v){
		for (int k=0; k<arr.length; k++){
			if (arr[k].equals(v))
				return k;
		}
		return -1;
	}
	
	//Returns the depth of graph vertex v in this edge's (M u lambda) ordering (M first, then lambda), i.e. the first index into rtm; -1 if absent
	private int depthOf(int v){
		int d = 0;
		for (Integer x : M){
			if (x.intValue()==v)
				return d;
			d++;
		}
		for (Integer x : lambda){
			if (x.intValue()==v)
				return d;
			d++;
		}
		return -1;
	}
	
	//Returns the residue/aa/rot tuple backtracked for the k-th vertex of (L-lambda), or null if it has not been looked up yet
	private static RotTypeMap llTuple(int k, Object arrayLL[], int curStateLL[], TreeEdge curStateLLedges[]){
		TreeEdge fk = curStateLLedges[k];
		if (fk==null)
			return null;
		//curStateLL[k] is a state index of fk's lambda vertex, stored at that vertex's depth in fk.rtm
		return fk.rtm[fk.depthOf((Integer)arrayLL[k])][curStateLL[k]];
	}
	
	//Returns the residue/aa/rot tuple currently assigned to graph vertex v: from curState if v is in this edge's (M u lambda),
	//	otherwise from the already backtracked (L-lambda) states; null if no state is known yet
	private RotTypeMap lookupTuple(int v, int curState[], Object topArrayLL[], int curStateLL[], TreeEdge curStateLLedges[]){
		
		int d = depthOf(v);
		if (d>=0)
			return rtm[d][curState[d]];
		
		int k = indexOfVertex(topArrayLL, v);
		if (k>=0)
			return llTuple(k, topArrayLL, curStateLL, curStateLLedges);
		
		return null;
	}
	
	//Given the state assignments in curState[] (for this edge) and the (L-lambda) states backtracked so far,
	//	get the state assignment for the graph vertex v (which must be in e.lambda) from the A matrix of tree edge e
	private int getStateForV(int curState[], int v, TreeEdge e, Object topArrayLL[], int curStateLL[], TreeEdge curStateLLedges[]){
		
		Object eM[] = e.getM().toArray();
		RotTypeMap ertm[][] = e.getrtm();
		
		int ind[] = new int[eM.length]; //state index of each vertex of e.M, in e's own state numbering
		
		for (int i=0; i<eM.length; i++){ //find the state for each vertex in e.M
			
			RotTypeMap t = lookupTuple((Integer)eM[i], curState, topArrayLL, curStateLL, curStateLLedges);
			if (t==null) //NOTE(review): state not known yet; ind[i] keeps the previous fallback of state 0 -- confirm this cannot occur
				continue;
			
			for (int k=0; k<ertm[i].length; k++){ //depth i of e.rtm holds the i-th vertex of e.M
				if ( (ertm[i][k].pos==t.pos) && (ertm[i][k].aa==t.aa) && (ertm[i][k].rot==t.rot) ){ //found state for vertex eM[i]
					ind[i] = k;
					break;
				}
			}
		}
		
		int vInd = -1;
		Object eLambda[] = e.getLambda().toArray();
		for (int i=0; i<eLambda.length; i++){
			if ((Integer)eLambda[i]==v){ //graph vertex found in e.lambda
				vInd = i;
				break;
			}
		}
		
		return e.getA()[e.computeIndexInA(ind,eM)][vInd]; //return the best state for graph vertex v using the information in e.A
	}
	
	//Compute the (partial) energy for the given state assignment;
	//	Includes the shell, intra, pairwise, and res-to-template energies of the residue positions in (M u L) for this edge;
	//	Only include pairwise interactions if an edge is present in the residue interaction graph.
	//	arrayLL[] lists the (L-lambda) vertices in the same order as curStateLL[] / curStateLLedges[].
	private float computeEforState(int curState[], int curStateLL[], TreeEdge curStateLLedges[], Object arrayLL[], float eMatrix[][][][][][], 
			Molecule m, InteractionGraph G){
		
		int numMLambda = M.size() + lambda.size();
		int numPos = numMLambda + arrayLL.length; //|M| + |L|, since lambda is a subset of L
		
		//resolve the residue/aa/rot tuple of every position once: (M u lambda) from curState, (L-lambda) from the backtracked states
		RotTypeMap tuples[] = new RotTypeMap[numPos];
		for (int i=0; i<numMLambda; i++)
			tuples[i] = rtm[i][curState[i]];
		for (int k=0; k<arrayLL.length; k++){
			tuples[numMLambda+k] = llTuple(k, arrayLL, curStateLL, curStateLLedges);
			if (tuples[numMLambda+k]==null)
				throw new IllegalStateException("No state was looked up for graph vertex "+arrayLL[k]);
		}
		
		float en = eMatrix[eMatrix.length-1][0][0][0][0][0]; // Add shell-shell energy
		
		for (int i=0; i<numPos; i++){
			
			RotTypeMap ti = tuples[i];
			
			en += eMatrix[ti.pos][ti.aa][ti.rot][ti.pos][0][1]; // Add the rotamer-shell energy
			en += eMatrix[ti.pos][ti.aa][ti.rot][ti.pos][0][0]; // Add the intra-rotamer energy
			
			for (int j=i+1; j<numPos; j++){
				
				RotTypeMap tj = tuples[j];
				
				if (G.edgeExists(ti.pos, tj.pos)) { //the two residues interact in the interaction graph, so add their pairwise energy
					en += eMatrix[ti.pos][ti.aa][ti.rot][tj.pos][tj.aa][tj.rot];
				}
			}
		}

		return en;
	}
	
	//Store the best state for each vertex in lambda (and the corresponding energy), for the state assignment in M contained in bestState[]
	private void storeBestStateLambda(int bestState[], Object arrayM[], float bestEn) {
		
		int curIndInA = computeIndexInA(bestState, arrayM); //get the index corresponding to the current state assignment in M
		
		for (int i=0; i<lambda.size(); i++){
			A[curIndInA][i] = bestState[arrayM.length+i];
		}
		energy[curIndInA] = bestEn;
	}
	
	/**
	 * Computes the row index into the A matrix for this tree edge, given the state assignments for the vertices
	 * of M in curState[]. The index is mixed-radix: the last vertex of arrayM varies fastest, and each vertex
	 * contributes state * (product of the state counts of the vertices after it). The result lies in
	 * [0, product of all state counts), which matches the number of rows allocated for A.
	 *
	 * @param curState state index of each vertex of arrayM (position i corresponds to arrayM[i])
	 * @param arrayM the vertices of M, in the order used for curState
	 * @return the row index into A
	 */
	public int computeIndexInA(int curState[], Object arrayM[]){
		
		int index = 0;
		int stride = 1; //product of the state counts of the vertices after position i
		
		for (int i=(arrayM.length-1); i>=0; i--){ //find the state assignment for the vertices in M
			index += curState[i] * stride;
			stride *= numStates[invResMap[(Integer)arrayM[i]]];
		}
		
		return index;
	}
	
	//Maps state indices to residue/aa/rot tuples; static (no hidden outer reference) and Serializable so that TreeEdge can be serialized
	private static class RotTypeMap implements Serializable {
		int pos = -1;
		int aa = -1;
		int rot = -1;
		RotTypeMap(int p, int a, int r){
			pos = p;
			aa = a;
			rot = r;
		}
	}
	
	public RotTypeMap[][] getrtm(){
		return rtm;
	}
}
