#-------------------------------------------------------------------------------
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
# USA
# 
# Contact Info:
# 	Bruce Donald
# 	Duke University
# 	Department of Computer Science
# 	Levine Science Research Center (LSRC)
# 	Durham
# 	NC 27708-0129 
# 	USA
# 	brd@cs.duke.edu
# 
# Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
# 
# <signature of Bruce Donald>, April 2011
# Bruce Donald, Professor of Computer Science
#-------------------------------------------------------------------------------


import jpype
import os
import jvm
from sys import stderr


def onJvmStartup( jvm ):
	"""Startup hook that wires this module to the Java library.

	Binds the module-global ``f`` to the ``edu.duke.donaldLab.share`` Java
	package (every other helper in this module goes through ``f``), then
	initializes the Java-side logging via ``io.Logging.Normal.init()``.

	:param jvm: presumably the JVM wrapper handed over by the startup-handler
	            mechanism (unused here; note it shadows the ``jvm`` module
	            inside this function)
	"""
	global f
	
	sharePackage = jpype.JPackage( "edu.duke.donaldLab.share" )
	f = sharePackage
	
	# configure the Java library's logging
	sharePackage.io.Logging.Normal.init()
	

def initJvm( jvm ):
	"""Register this module's startup hook on *jvm* and hand it back.

	:param jvm: a JVM wrapper exposing ``addStartupHandler``
	:returns: the same *jvm* object, so calls can be chained
	"""
	register = jvm.addStartupHandler
	register( onJvmStartup )
	return jvm

def getJvm( ):
	"""Create a new ``jvm.Jvm`` and register this module's startup hook on it.

	:returns: the freshly created wrapper, after ``initJvm``
	"""
	newJvm = jvm.Jvm()
	return initJvm( newJvm )


def loadProteins( file ):
	"""Read every model from *file*, name-map them, and print a summary.

	Each model is passed through ``NameMapper.ensureProtein`` with the
	``NameScheme.New`` scheme. The summary statistics are taken from the
	first model.

	:param file: path handed to ``f.pdb.ProteinReader().readAll``
	:returns: the list of models returned by the reader. It may be empty; in
	          that case a warning goes to stderr instead of the statistics.
	"""
	
	# read it
	print( "Loading protein models from:\n\t%s" % file )
	proteins = f.pdb.ProteinReader().readAll( file )
	
	# map the names
	for protein in proteins:
		f.mapping.NameMapper.ensureProtein( protein, f.mapping.NameScheme.New )
	
	# every statistic below reads the first model, so stop here when there
	# is none instead of failing on proteins.get( 0 )
	if proteins.size() == 0:
		stderr.write( "Warning: no protein models were read from %s\n" % file )
		return proteins
	
	# dump some stats
	first = proteins.get( 0 )
	print( "\tloaded %s:" % first.getName() )
	print( "\t\t%i models" % proteins.size() )
	print( "\t\t%i subunits" % first.getSubunits().size() )
	print( "\t\t%i residues" % first.getSubunit( 0 ).getResidues().size() )
	
	return proteins


def loadProteinsFromDir( inDir ):
	"""Read the protein models from every regular file in *inDir*.

	Entries are visited in sorted-name order. Subdirectories, and anything
	else that is not a regular file, are skipped. All models are collected
	into one ``java.util.ArrayList`` and then name-mapped to
	``NameScheme.New``.

	:param inDir: directory containing the structure files
	:returns: a ``java.util.ArrayList`` holding every model read
	"""
	
	proteins = jvm.f.java.util.ArrayList()
	for name in sorted( os.listdir( inDir ) ):
		path = os.path.join( inDir, name )
		# a directory is not a readable structure file; readAll() would
		# presumably fail on it, so skip anything that isn't a regular file
		if not os.path.isfile( path ):
			continue
		print( "Loading protein from:\t%s" % name )
		proteins.addAll( f.pdb.ProteinReader().readAll( path ) )
	
	# map the names
	for protein in proteins:
		f.mapping.NameMapper.ensureProtein( protein, f.mapping.NameScheme.New )
	
	return proteins
	

def loadProtein( file ):
	"""Read a single protein from *file*, name-map it, and print a summary.

	:param file: path handed to ``f.pdb.ProteinReader().read``
	:returns: the loaded protein, with names mapped to ``NameScheme.New``
	"""
	print( "Loading protein from:\n\t%s" % file )
	reader = f.pdb.ProteinReader()
	protein = reader.read( file )
	
	# normalize names to the "New" naming scheme
	mapping = f.mapping
	mapping.NameMapper.ensureProtein( protein, mapping.NameScheme.New )
	
	# summary (residue count is taken from the first subunit only)
	print( "\tloaded %s:" % protein.getName() )
	print( "\t\t%i subunits" % protein.getSubunits().size() )
	print( "\t\t%i residues" % protein.getSubunit( 0 ).getResidues().size() )
	return protein


def writeProtein( file, protein ):
	"""Write *protein* to *file* with ``f.pdb.ProteinWriter`` and report the path."""
	writer = f.pdb.ProteinWriter()
	writer.write( protein, file )
	print( "Wrote protein to:\n\t%s" % file )


def writeProteins( file, proteins ):
	"""Write all models in *proteins* to *file* and report how many were written."""
	writer = f.pdb.ProteinWriter()
	writer.write( proteins, file )
	numModels = len( proteins )
	print( "Wrote %d models to:\n\t%s" % ( numModels, file ) )


def loadDistanceRestraints( file ):
	"""Read distance restraints from *file* and print how many were loaded.

	:param file: path handed to ``f.nmr.DistanceRestraintReader().read``
	:returns: the restraint collection returned by the reader
	"""
	print( "Loading distance restraints from:\n\t%s" % file )
	reader = f.nmr.DistanceRestraintReader()
	restraints = reader.read( file )
	print( "\tloaded %i distance restraints" % restraints.size() )
	return restraints


def writeDistanceRestraints( file, restraints ):
	"""Write *restraints* to *file* and report the count and destination.

	Note: ``DistanceRestraintWriter.write`` takes (file, restraints), the
	reverse of ``ProteinWriter.write``'s argument order.
	"""
	writer = f.nmr.DistanceRestraintWriter()
	writer.write( file, restraints )
	count = len( restraints )
	print( "Wrote %d distance restraints to:\n\t%s" % ( count, file ) )

	
def interpretDistanceRestraintsWithPseudoatoms( restraints, protein ):
	"""Rewrite *restraints* in terms of *protein*'s pseudoatoms.

	Delegates to ``PseudoatomBuilder.buildDistanceRestraints`` and prints the
	number it reports as changed.
	"""
	builder = f.pseudoatoms.PseudoatomBuilder.getInstance()
	changedCount = builder.buildDistanceRestraints( protein, restraints )
	print( "%i distance restraints interpreted with pseudoatoms" % changedCount )


def mapDistanceRestraintsToProtein( restraints, protein ):
	"""Convert human-readable distance restraints to *protein*'s internal form.

	Steps:
	1. Name-map *restraints* to ``NameScheme.New`` against *protein*.
	2. If any restraint references pseudoatoms, build them on *protein*.
	3. Map readable restraints to internal ones with
	   ``DistanceRestraintMapper.mapReadableToInternal``.

	:returns: the internal restraints (the count may differ from the input's)
	"""
	print( "Mapping NOEs..." )
	
	mapping = f.mapping
	mapping.NameMapper.ensureDistanceRestraints( protein, restraints, mapping.NameScheme.New )
	
	# the protein needs pseudoatoms before restraints that use them can be mapped
	builder = f.pseudoatoms.PseudoatomBuilder
	if builder.distanceRestraintsHavePseudoatoms( restraints ):
		builtCount = builder.getInstance().build( protein )
		print( "\tBuilt %i Pseudoatoms" % builtCount )
	
	countIn = len( restraints )
	internal = f.nmr.DistanceRestraintMapper.mapReadableToInternal( restraints, protein )
	countOut = len( internal )
	print( "\tmapped %d readable restraints to %d internal restraints" % ( countIn, countOut ) )
	
	return internal


def unmapDistanceRestraintsFromProtein( restraints, protein, collapsePseudoatoms=False ):
	"""Convert internal distance restraints back to human-readable form.

	:param restraints: internal restraints bound to *protein*
	:param protein: the structure the restraints refer to
	:param collapsePseudoatoms: forwarded unchanged to
	                            ``DistanceRestraintMapper.mapInternalToReadable``
	:returns: the readable restraints (the count may differ from the input's)
	"""
	print( "Unmapping distance restraints..." )
	
	countIn = len( restraints )
	Mapper = f.nmr.DistanceRestraintMapper
	readable = Mapper.mapInternalToReadable( restraints, protein, collapsePseudoatoms )
	countOut = len( readable )
	print( "\tUnmapped %d internal restraints to %d readable restraints" % ( countIn, countOut ) )
	
	return readable


def filterRestraintsIntrasubunit( restraints, subunitId=None ):
	"""Keep the unique intrasubunit restraints, optionally for one subunit only.

	Prints the surviving count after each filter stage.

	:param restraints: restraints to filter (not modified)
	:param subunitId: when given, also keep only restraints between this
	                  subunit and itself
	:returns: the filtered restraints
	"""
	Filterer = f.nmr.DistanceRestraintFilterer
	print( "Filtering %i restraints..." % len( restraints ) )
	
	kept = Filterer.pickIntrasubunit( restraints )
	print( "\tIntrasubunit Filter: %i" % len( kept ) )
	
	if subunitId is not None:
		kept = Filterer.pickBetween( kept, subunitId, subunitId )
		print( "\tSubunit Id %d Filter: %i" % ( subunitId, len( kept ) ) )
	
	kept = Filterer.pickUnique( kept )
	print( "\tUnique Filter: %i" % len( kept ) )
	return kept


def filterRestraintsIntersubunit( restraints ):
	"""Reduce *restraints* to unique intersubunit restraints, one per symmetric group.

	The filters run in a fixed order (intersubunit, touching subunit 0,
	unique, one per symmetric group). The surviving count is printed after
	each stage.

	:returns: the filtered restraints
	"""
	Filterer = f.nmr.DistanceRestraintFilterer
	print( "Filtering %i restraints..." % len( restraints ) )
	
	# (message template, filter) pairs applied in order
	pipeline = (
		( "\tIntersubunit Filter: %i", lambda rs: Filterer.pickIntersubunit( rs ) ),
		( "\tFirst Subunit Filter: %i", lambda rs: Filterer.pickSubunitEitherSide( rs, 0 ) ),
		( "\tUnique Filter: %i", lambda rs: Filterer.pickUnique( rs ) ),
		( "\tSymmetry Filter: %i", lambda rs: Filterer.pickOneFromSymmetricGroup( rs ) ),
	)
	for template, stage in pipeline:
		restraints = stage( restraints )
		print( template % len( restraints ) )
	return restraints


def mapRestraintsToSubunit( restraints, subunitId ):
	"""Return the result of ``DistanceRestraintFilterer.mapToSubunit( restraints, subunitId )``."""
	Filterer = f.nmr.DistanceRestraintFilterer
	mapped = Filterer.mapToSubunit( restraints, subunitId )
	return mapped


def reportMismatchedRestraints( restraints, numSubunits ):
	"""Warn on stderr about restraints the filterer flags as mismatched.

	Prints nothing when ``pickMismatchedRestraints`` returns an empty
	collection. Otherwise it prints a count line followed by one tab-indented
	line per mismatched restraint.
	"""
	Filterer = f.nmr.DistanceRestraintFilterer
	mismatched = Filterer.pickMismatchedRestraints( restraints, numSubunits )
	numMismatched = len( mismatched )
	if numMismatched == 0:
		return
	print >> stderr, "Warning: %i mismatched restraints!" % numMismatched
	for bad in mismatched:
		print >> stderr, "\t", bad


def padRestraints( restraints, padPercent ):
	"""Widen every restraint's distance bounds in place by a relative margin.

	Each maximum distance is multiplied by ``1 + padPercent`` and each
	minimum distance by ``1 - padPercent``. For example, ``padPercent=0.1``
	widens both bounds by 10%.

	:param restraints: sized iterable of restraints exposing
	                   get/setMaxDistance and get/setMinDistance. Either a
	                   Java list or a plain Python list works.
	:param padPercent: the padding as a fraction (0.1 means 10%), not as a
	                   percentage
	"""
	upperScale = 1.0 + padPercent
	lowerScale = 1.0 - padPercent
	for restraint in restraints:
		restraint.setMaxDistance( restraint.getMaxDistance() * upperScale )
		restraint.setMinDistance( restraint.getMinDistance() * lowerScale )
	# len() rather than .size() so Python sequences are accepted as well
	print( "Padded %d restraints by %0.1f%%" % ( len( restraints ), padPercent * 100 ) )


def loadShifts( file ):
	"""Read chemical shifts from *file* and print how many were loaded.

	:param file: path handed to ``f.nmr.ChemicalShiftReader().read``
	:returns: the shift collection returned by the reader
	"""
	print( "Loading Chemical Shifts from:\n\t%s" % file )
	reader = f.nmr.ChemicalShiftReader()
	loaded = reader.read( file )
	print( "\tloaded %i Chemical Shifts" % loaded.size() )
	return loaded


def interpretShiftsWithPseudoatoms( shifts, protein ):
	"""Rewrite *shifts* in terms of *protein*'s pseudoatoms.

	Delegates to ``PseudoatomBuilder.buildShifts`` and prints the number it
	reports as changed.
	"""
	builder = f.pseudoatoms.PseudoatomBuilder.getInstance()
	changedCount = builder.buildShifts( protein, shifts )
	print( "%i chemical shifts interpreted with pseudoatoms" % changedCount )


def mapShifts( shifts, protein ):
	"""Name-map *shifts* onto *protein* and convert them to mapped chemical shifts.

	:returns: the result of ``f.nmr.ChemicalShiftMapper.map( shifts, protein )``
	"""
	print( "Mapping Chemical Shifts..." )
	
	mapping = f.mapping
	mapping.NameMapper.ensureShifts( protein, shifts, mapping.NameScheme.New )
	
	result = f.nmr.ChemicalShiftMapper.map( shifts, protein )
	countIn = shifts.size()
	countOut = result.size()
	print( "\tmapped %i Chemical Shifts to %i Mapped Chemical Shifts" % ( countIn, countOut ) )
	return result


def cloneMonomer( monomer, numSubunits ):
	"""Return ``MonomerCloner.clone( monomer, numSubunits )`` from the Java library."""
	cloner = f.protein.tools.MonomerCloner
	return cloner.clone( monomer, numSubunits )


def getRmsd( reference, compared ):
	"""Return ``RmsdCalculator.getRmsd( reference, compared )``."""
	calculator = f.analysis.RmsdCalculator
	return calculator.getRmsd( reference, compared )


def getDistanceRestraintRmsd( structure, restraints ):
	"""Return the restraint RMSD of *structure* as computed by a new ``RestraintCalculator``."""
	calculator = f.analysis.RestraintCalculator()
	return calculator.getRmsd( structure, restraints )


def getDistanceRestraintNumSatisfied( structure, restraints ):
	"""Return how many *restraints* a new ``RestraintCalculator`` reports as satisfied by *structure*."""
	calculator = f.analysis.RestraintCalculator()
	return calculator.getNumSatisfied( structure, restraints )


def getAverageVariance( proteins ):
	"""Return ``VarianceCalculator.getAverageVariance( proteins )``."""
	calculator = f.analysis.VarianceCalculator
	return calculator.getAverageVariance( proteins )


def getAverageRmsd( proteins ):
	"""Return ``VarianceCalculator.getAverageRmsd( proteins )``."""
	calculator = f.analysis.VarianceCalculator
	return calculator.getAverageRmsd( proteins )

	
def printRestraintsStats( restraints ):
	"""Print the total number of restraints and how many are (un)ambiguous.

	:param restraints: sized iterable of restraints exposing ``isAmbiguous()``
	"""
	# one pass over the restraints; True marks an ambiguous restraint
	flags = [ bool( restraint.isAmbiguous() ) for restraint in restraints ]
	numAmbiguous = sum( flags )
	numUnambiguous = len( flags ) - numAmbiguous
	
	print( "Restraints stats:" )
	print( "\ttotal: %i" % len( restraints ) )
	print( "\tunambiguous: %i" % numUnambiguous )
	# colon added to match the other "label: count" lines
	print( "\tambiguous: %i" % numAmbiguous )
	

def alignOptimally( reference, compared ):
	"""Center *reference*, then run ``StructureAligner.alignOptimally`` against *compared*.

	Both arguments may be modified in place by the Java calls. Presumably
	*compared* is moved onto *reference*; confirm against StructureAligner.
	"""
	geometry = f.protein.tools.ProteinGeometry
	aligner = f.analysis.StructureAligner
	geometry.center( reference )
	aligner.alignOptimally( reference, compared )
	
	
def alignOptimallyBySubunit( reference, compared, subunitId ):
	"""Align *compared* to *reference* using only subunit *subunitId*.

	*reference* is first translated by the negated centroid of its
	*subunitId* subunit, which moves that centroid to the origin. Then
	``StructureAligner.alignOptimallyBySubunit`` is run.
	"""
	geometry = f.protein.tools.ProteinGeometry
	offset = geometry.getCentroid( reference.getSubunit( subunitId ) )
	offset.negate()
	geometry.translate( reference, offset )
	f.analysis.StructureAligner.alignOptimallyBySubunit( reference, compared, subunitId )


def alignEnsembleOptimally( reference, ensemble ):
	"""Center *reference*, then align each member of *ensemble* to it in turn."""
	f.protein.tools.ProteinGeometry.center( reference )
	aligner = f.analysis.StructureAligner
	for member in ensemble:
		aligner.alignOptimally( reference, member )


def alignEnsembleOptimallyByAtoms( reference, ensemble, addresses ):
	"""Center *reference*, then align each ensemble member to it over the atoms in *addresses*.

	The same *addresses* are used for both the reference and the member atoms.
	"""
	f.protein.tools.ProteinGeometry.center( reference )
	aligner = f.analysis.StructureAligner
	for member in ensemble:
		aligner.alignOptimallyByAtoms( reference, member, addresses, addresses )


def alignEnsembleOptimallyBySubunit( reference, ensemble, subunitId ):
	"""Align every ensemble member to *reference* using only subunit *subunitId*.

	*reference* is translated once by the negated centroid of its
	*subunitId* subunit, which moves that centroid to the origin. Each member
	is then aligned with ``StructureAligner.alignOptimallyBySubunit``.
	"""
	geometry = f.protein.tools.ProteinGeometry
	offset = geometry.getCentroid( reference.getSubunit( subunitId ) )
	offset.negate()
	geometry.translate( reference, offset )
	
	aligner = f.analysis.StructureAligner
	for member in ensemble:
		aligner.alignOptimallyBySubunit( reference, member, subunitId )


def saveEnsembleKinemage( pathOut, ensemble, reference=None ):
	"""Write a kinemage of the backbones of *ensemble*, optionally with a reference.

	Ensemble members are labelled "Ensemble 1", "Ensemble 2", ... in
	iteration order.

	:param pathOut: output path for the kinemage file
	:param ensemble: iterable of structures to draw
	:param reference: optional structure drawn as "Reference"; skipped when None
	"""
	
	kinemage = f.kinemage.Kinemage( "Ensemble" )
	KinemageBuilder = f.kinemage.KinemageBuilder
	KinemageBuilder.appendAxes( kinemage )
	
	# add the reference structure; an identity test is the reliable None
	# check, because != may dispatch to the object's own comparison
	if reference is not None:
		KinemageBuilder.appendBackbone( kinemage, reference, "Reference", 1, 3 )
	
	# add the aligned ensemble structures, numbered from 1
	for i, structure in enumerate( ensemble, 1 ):
		KinemageBuilder.appendBackbone( kinemage, structure, "Ensemble %d" % i, 0, 2 )
	
	# write it!
	f.kinemage.KinemageWriter().write( kinemage, pathOut )


def savePairKinemage( pathOut, reference, computed ):
	"""Write a kinemage with the backbones of *reference* and *computed*, then report the path.

	:param pathOut: output path for the kinemage file
	:param reference: structure labelled "Reference"
	:param computed: structure labelled "Computed"
	"""
	builder = f.kinemage.KinemageBuilder
	kin = f.kinemage.Kinemage( "Pair" )
	builder.appendAxes( kin )
	builder.appendBackbone( kin, reference, "Reference", 1, 3 )
	builder.appendBackbone( kin, computed, "Computed", 0, 2 )
	
	writer = f.kinemage.KinemageWriter()
	writer.write( kin, pathOut )
	print( "Kinemage with 2 structures written to:\n\t%s" % pathOut )


def saveStructureKinemage( pathOut, protein, restraints=None ):
	"""Write a kinemage of *protein*, optionally overlaying distance restraints.

	:param pathOut: output path for the kinemage file
	:param protein: structure drawn via ``KinemageBuilder.appendProtein``
	:param restraints: optional distance restraints to draw; skipped when None
	"""
	
	# build the kinemage
	# NOTE(review): the title "Pair" looks copied from savePairKinemage even
	# though only one structure is drawn; left unchanged to keep output stable
	kinemage = f.kinemage.Kinemage( "Pair" )
	KinemageBuilder = f.kinemage.KinemageBuilder
	KinemageBuilder.appendAxes( kinemage )
	KinemageBuilder.appendProtein( kinemage, protein )
	
	# an identity test is the reliable None check, because != may dispatch to
	# the object's own comparison
	if restraints is not None:
		KinemageBuilder.appendDistanceRestraints( kinemage, protein, restraints )
	
	# write it!
	f.kinemage.KinemageWriter().write( kinemage, pathOut )
	print( "Kinemage with 1 structure written to:\n\t%s" % pathOut )


def runKing( pathKinemage ):
	"""Open *pathKinemage* in the ``king`` kinemage viewer and wait for it to exit.

	The path is passed as its own argv entry instead of being spliced into a
	shell command string. Quotes, ``$``, backticks or spaces in the path
	therefore cannot break the command or inject into it.

	:param pathKinemage: path of the kinemage file to display
	"""
	import subprocess
	try:
		subprocess.call( [ "king", pathKinemage ] )
	except OSError as ex:
		# the old os.system() call only reported a missing viewer through the
		# shell and never raised; stay best-effort here too
		stderr.write( "Unable to launch king: %s\n" % ex )
	

def minimize( protein, distanceRestraints=None, logPath=None ):
	"""Run ``f.xplor.StructureMinimizer`` on *protein*.

	:param protein: structure to minimize
	:param distanceRestraints: optional restraints passed to the minimizer
	                           when not None
	:param logPath: optional log destination passed to the minimizer when
	                not None
	:returns: whatever ``StructureMinimizer.minimize( protein )`` returns
	"""
	minimizer = f.xplor.StructureMinimizer()
	# identity tests: != may dispatch to the object's own comparison
	if distanceRestraints is not None:
		minimizer.setDistanceRestraints( distanceRestraints )
	if logPath is not None:
		minimizer.setLog( logPath )
	return minimizer.minimize( protein )


def calculateEnergy( protein, logPath=None ):
	"""Return the energy of *protein* from ``f.xplor.EnergyCalculator``.

	:param protein: structure to evaluate
	:param logPath: optional log destination passed to the calculator when
	                not None
	:returns: whatever ``EnergyCalculator.getEnergy( protein )`` returns
	"""
	calculator = f.xplor.EnergyCalculator()
	if logPath is not None:
		calculator.setLog( logPath )
	return calculator.getEnergy( protein )


def newSubunitOrder( subunitOrder ):
	"""Construct and return a Java ``SubunitOrder`` from *subunitOrder*."""
	SubunitOrder = f.protein.SubunitOrder
	return SubunitOrder( subunitOrder )


def newAlignmentTensor( protein, rdcs, numSamples=0 ):
	"""Compute one alignment tensor from *rdcs* on *protein* and print its stats.

	:param protein: structure the couplings refer to
	:param rdcs: presumably residual dipolar couplings (confirm with caller)
	:param numSamples: accepted but NOT used by this function;
	                   ``newAlignmentTensors`` is the variant that forwards it
	:returns: the tensor from ``AlignmentTensor.compute( protein, rdcs )``
	"""
	AlignmentTensor = f.nmr.AlignmentTensor
	tensor = AlignmentTensor.compute( protein, rdcs )
	stats = tensor.getStats( protein, rdcs )
	print( stats )
	return tensor


def newAlignmentTensors( protein, rdcs, numSamples ):
	"""Return ``AlignmentTensor.compute( protein, rdcs, numSamples )`` from the Java library."""
	AlignmentTensor = f.nmr.AlignmentTensor
	return AlignmentTensor.compute( protein, rdcs, numSamples )

