#-------------------------------------------------------------------------------
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
# USA
# 
# Contact Info:
# 	Bruce Donald
# 	Duke University
# 	Department of Computer Science
# 	Levine Science Research Center (LSRC)
# 	Durham
# 	NC 27708-0129 
# 	USA
# 	brd@cs.duke.edu
# 
# Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
# 
# <signature of Bruce Donald>, April 2011
# Bruce Donald, Professor of Computer Science
#-------------------------------------------------------------------------------


import jpype, os
import jvm
from sys import stderr


def onJvmStartup( jvm ):
	"""Startup hook registered by initJvm().

	Binds the module-global shortcut `f` to the Java package `libprotnmr` and
	initialises its normal logging. The `jvm` argument is not used here.
	"""
	global f
	f = jpype.JPackage( "libprotnmr" )
	
	# every other helper in this module reaches Java classes through `f`
	f.io.Logging.Normal.init()
	

def initJvm( jvm ):
	"""Register onJvmStartup as a startup handler on `jvm` and return the same object."""
	startupHandler = onJvmStartup
	jvm.addStartupHandler( startupHandler )
	return jvm

def getJvm( ):
	"""Create a new jvm.Jvm wrapper with this module's startup handler already registered."""
	newJvm = jvm.Jvm()
	return initJvm( newJvm )


def loadProteins( file ):
	
	# read it
	print "Loading protein models from:\n\t%s" % file
	proteins = f.pdb.ProteinReader().readAll( file )
	
	# map the names
	firstProtein = None
	for protein in proteins:
		if protein is not None:
			f.mapping.NameMapper.ensureProtein( protein, f.mapping.NameScheme.New )
			if firstProtein is None:
				firstProtein = protein
	
	# dump some stats
	if firstProtein is not None:
		print "\tloaded %s:" % firstProtein.getName()
		print "\t\t%i models" % proteins.size()
		print "\t\t%i subunits" % firstProtein.getSubunits().size()
		print "\t\t%i residues" % firstProtein.getSubunit( 0 ).getResidues().size()
	else:
		print "\tloaded %d null models!" % proteins.size()

	return proteins


def loadProteinsFromDir( inDir ):
	
	proteins = jvm.toArrayList( [] )
	for file in sorted( os.listdir( inDir ) ):
		print "Loading protein from:\t%s" % file
		proteins.addAll( f.pdb.ProteinReader().readAll( os.path.join( inDir, file ) ) )
		
	# map the names
	for i in range( 0, len( proteins ) ):
		f.mapping.NameMapper.ensureProtein( proteins.get( i ), f.mapping.NameScheme.New )
		
	return proteins
	

def loadProtein( file ):
	
	# read it
	print "Loading protein from:\n\t%s" % file
	protein = f.pdb.ProteinReader().read( file )
	
	# map the names
	f.mapping.NameMapper.ensureProtein( protein, f.mapping.NameScheme.New )
	
	# dump some stats
	print "\tloaded %s:" % protein.getName()
	print "\t\t%i subunits" % protein.getSubunits().size()
	print "\t\t%i residues" % protein.getSubunit( 0 ).getResidues().size()
	
	return protein


def writeProtein( file, protein ):
	
	f.pdb.ProteinWriter().write( protein, file )
	print "Wrote protein to:\n\t%s" % file


def writeProteins( file, proteins ):
	f.pdb.ProteinWriter().write( proteins, file )
	print "Wrote %d models to:\n\t%s" % ( len( proteins ), file )


def loadDistanceRestraints( file ):
	
	# read it
	print "Loading distance restraints from:\n\t%s" % file
	distanceRestraints = f.nmr.DistanceRestraintReader().read( file )
	print "\tloaded %i distance restraints" % distanceRestraints.size()
	
	return distanceRestraints


def writeDistanceRestraints( file, restraints ):
	
	f.nmr.DistanceRestraintWriter().write( file, restraints )
	print "Wrote %d distance restraints to:\n\t%s" % ( len( restraints ), file )

	
def interpretDistanceRestraintsWithPseudoatoms( restraints, sequences ):
	
	numChanged = f.pseudoatoms.PseudoatomBuilder.getInstance().buildDistanceRestraints( sequences, restraints )
	print "%i distance restraints interpreted with pseudoatoms" % numChanged


def uninterpretDistanceRestraintsWithPseudoatoms( restraints, sequences ):
	
	numChanged = f.pseudoatoms.PseudoatomBuilder.getInstance().unbuildDistanceRestraints( sequences, restraints )
	print "%i distance restraints uninterpreted with pseudoatoms" % numChanged


def mapDistanceRestraintsToProtein( readableRestraints, protein, addNulls = False ):
	"""Map readable distance restraints onto `protein`, producing internal restraints.

	Steps:
	  1. map atom names in the restraints to NameScheme.New (via mapAtomNames);
	  2. if the restraints reference pseudoatoms and the protein has none yet,
	     build them on the protein (this mutates `protein`);
	  3. map with DistanceRestraintMapper.mapReadableToInternal, passing `addNulls`.

	If fewer restraints come out than went in, a warning and every unmapped
	restraint are printed to stderr.

	NOTE(review): when addNulls is True the mapper returns one entry per input
	(see the size check below), so the before/after counts are equal and the
	unmapped-restraint warning never fires even if some entries are None.

	Returns the internal restraints from the first mapping call.
	Raises Exception if the diagnostic re-mapping does not return one entry per input.
	"""
	
	print "Mapping Distance Restraints..."
	
	mapAtomNames( readableRestraints, protein.getSequences() )
	
	# add pseudoatoms to the protein if needed
	PseudoatomBuilder = f.pseudoatoms.PseudoatomBuilder
	if PseudoatomBuilder.distanceRestraintsHavePseudoatoms( readableRestraints ):
		if not PseudoatomBuilder.getInstance().hasPseudoatoms( protein ):
			print "\tBuilt %i Pseudoatoms" % PseudoatomBuilder.getInstance().build( protein )

	numRestraintsBefore = len( readableRestraints )
	internalRestraints = f.nmr.DistanceRestraintMapper.mapReadableToInternal( readableRestraints, protein, addNulls )
	numRestraintsAfter = len( internalRestraints )
	
	print "\tmapped %d readable restraints to %d internal restraints" % (
		numRestraintsBefore,
		numRestraintsAfter
	)
	
	# if restraints weren't mapped, show a warning!
	if numRestraintsBefore > numRestraintsAfter:
		
		numMissing = numRestraintsBefore - numRestraintsAfter
		print >> stderr, "%d distance restraints were not mapped!" % numMissing
		
		# show the unmapped restraints: re-map with nulls kept so output index i
		# lines up with input index i, and a None marks a restraint that failed
		internalRestraintsWithNulls = f.nmr.DistanceRestraintMapper.mapReadableToInternal( readableRestraints, protein, True )
		if internalRestraintsWithNulls.size() != readableRestraints.size():
			raise Exception( "Distance restraint mapping error reporting failed!" )
		for i in range( readableRestraints.size() ):
			if internalRestraintsWithNulls.get( i ) is None:
				print >> stderr, readableRestraints.get( i ).toString()
	
	return internalRestraints


def unmapDistanceRestraintsFromProtein( restraints, protein, collapsePseudoatoms=False ):
	
	print "Unmapping distance restraints..."
	
	numRestraintsBefore = len( restraints )
	restraints = f.nmr.DistanceRestraintMapper.mapInternalToReadable( restraints, protein, collapsePseudoatoms )
	numRestraintsAfter = len( restraints )
	
	print "\tUnmapped %d internal restraints to %d readable restraints" % (
		numRestraintsBefore,
		numRestraintsAfter
	)
	
	return restraints


def filterRestraintsIntrasubunit( restraints, subunitId=None ):
	
	Filterer = f.nmr.DistanceRestraintFilterer
	print "Filtering %i restraints..." % len( restraints )
	restraints = Filterer.pickIntrasubunit( restraints )
	print "\tIntrasubunit Filter: %i" % len( restraints )
	if subunitId is not None:
		restraints = Filterer.pickBetween( restraints, subunitId, subunitId )
		print "\tSubunit Id %d Filter: %i" % ( subunitId, len( restraints ) )
	restraints = Filterer.pickUnique( restraints )
	print "\tUnique Filter: %i" % len( restraints )
	return restraints


def filterRestraintsIntersubunit( restraints ):
	
	Filterer = f.nmr.DistanceRestraintFilterer
	
	print "Filtering %i restraints..." % len( restraints )
	restraints = Filterer.pickIntersubunit( restraints )
	print "\tIntersubunit Filter: %i" % len( restraints )
	restraints = Filterer.pickSubunitEitherSide( restraints, 0 )
	print "\tFirst Subunit Filter: %i" % len( restraints )
	restraints = Filterer.pickUnique( restraints )
	print "\tUnique Filter: %i" % len( restraints )
	restraints = Filterer.pickOneFromSymmetricGroup( restraints )
	print "\tSymmetry Filter: %i" % len( restraints )
	return restraints


def mapRestraintsToSubunit( restraints, subunitId ):
	"""Return DistanceRestraintFilterer.mapToSubunit( restraints, subunitId )."""
	Filterer = f.nmr.DistanceRestraintFilterer
	mapped = Filterer.mapToSubunit( restraints, subunitId )
	return mapped


def reportMismatchedRestraints( restraints, numSubunits ):
	"""Print a warning plus each mismatched restraint to stderr; print nothing if there are none.

	Mismatches are those returned by DistanceRestraintFilterer.pickMismatchedRestraints.
	"""
	mismatched = f.nmr.DistanceRestraintFilterer.pickMismatchedRestraints( restraints, numSubunits )
	if len( mismatched ) == 0:
		return
	
	print >> stderr, "Warning: %i mismatched restraints!" % len( mismatched )
	for badRestraint in mismatched:
		print >> stderr, "\t", badRestraint


def loadRdcs( file ):
	
	# read it
	print "Loading RDCs from:\n\t%s" % file
	rdcs = f.nmr.RdcReader().read( file )
	print "\tloaded %i RDCs" % rdcs.size()
	
	return rdcs


def mapRdcsToProtein( rdcs, protein ):
	
	print "Mapping RDCs..."
	
	mapAtomNames( rdcs, protein.getSequences() )
	
	numRdcsBefore = len( rdcs )
	internal = f.nmr.RdcMapper.mapReadableToInternal( protein, rdcs, False )
	numRdcsAfter = len( internal )
	
	print "\tmapped %d readable RDCs to %d internal RDCs" % (
		numRdcsBefore,
		numRdcsAfter
	)
	
	return internal


def loadShifts( file ):
	
	# read it
	print "Loading Chemical Shifts from:\n\t%s" % file
	shifts = f.nmr.ChemicalShiftReader().read( file )
	print "\tloaded %i Chemical Shifts" % shifts.size()
	
	return shifts


def interpretShiftsWithPseudoatoms( shifts, protein ):
	
	numChanged = f.pseudoatoms.PseudoatomBuilder.getInstance().buildShifts( protein, shifts )
	print "%i chemical shifts interpreted with pseudoatoms" % numChanged


def mapShifts( shifts, protein ):
	
	print "Mapping Chemical Shifts..."
	
	# map the names
	mapAtomNames( shifts, protein.getSequences() )
	
	mappedShifts = f.nmr.ChemicalShiftMapper.map( shifts, protein )
	print "\tmapped %i Chemical Shifts to %i Mapped Chemical Shifts" % ( shifts.size(), mappedShifts.size() )
	
	return mappedShifts


def loadDihedralRestraints( file ):
	
	# read it
	print "Loading Dihedral Restraints from:\n\t%s" % file
	restraints = f.nmr.DihedralRestraintReader().read( file )
	print "\tloaded %i Dihedral Restraints" % restraints.size()
	
	return restraints

def mapDihedralRestraintsToProtein( readableRestraints, protein, addNulls = False ):
	
	print "Mapping Dihedral Restraints..."
	
	mapAtomNames( readableRestraints, protein.getSequences() )
	
	numRestraintsBefore = len( readableRestraints )
	internalRestraints = f.nmr.DihedralRestraintMapper.mapReadableToInternal( readableRestraints, protein, addNulls )
	numRestraintsAfter = len( internalRestraints )
	
	print "\tmapped %d readable restraints to %d internal restraints" % (
		numRestraintsBefore,
		numRestraintsAfter
	)
	
	# if restraints weren't mapped, show a warning!
	if numRestraintsBefore > numRestraintsAfter:
		
		numMissing = numRestraintsBefore - numRestraintsAfter
		print >> stderr, "%d distance restraints were not mapped!" % numMissing
		
		# show the unmapped restraints
		internalRestraintsWithNulls = f.nmr.DihedralRestraintMapper.mapReadableToInternal( readableRestraints, protein, True )
		if internalRestraintsWithNulls.size() != readableRestraints.size():
			raise Exception( "Dihedral restraint mapping error reporting failed!" )
		for i in range( readableRestraints.size() ):
			if internalRestraintsWithNulls.get( i ) is None:
				print >> stderr, readableRestraints.get( i ).toString()
	
	return internalRestraints


def mapAtomNames( addresses, sequences ):
	"""Map atom names in `addresses` to NameScheme.New using `sequences` (NameMapper.ensureAddresses)."""
	targetScheme = f.mapping.NameScheme.New
	f.mapping.NameMapper.ensureAddresses( sequences, addresses, targetScheme )


def mapAtomNamesForXplor( addresses, sequences ):
	"""Map atom names in `addresses` to NameScheme.Old using `sequences` (NameMapper.ensureAddresses)."""
	targetScheme = f.mapping.NameScheme.Old
	f.mapping.NameMapper.ensureAddresses( sequences, addresses, targetScheme )


def getRmsd( reference, compared ):
	"""Return RmsdCalculator.getRmsd( reference, compared )."""
	calculator = f.analysis.RmsdCalculator
	return calculator.getRmsd( reference, compared )


def getDistanceRestraintRmsd( structure, restraints ):
	"""Return the restraint RMSD of `structure` from a fresh RestraintCalculator."""
	calculator = f.analysis.RestraintCalculator()
	return calculator.getRmsd( structure, restraints )


def getDistanceRestraintNumSatisfied( structure, restraints ):
	"""Return how many `restraints` RestraintCalculator reports as satisfied by `structure`."""
	calculator = f.analysis.RestraintCalculator()
	return calculator.getNumSatisfied( structure, restraints )


def getAverageVariance( proteins ):
	"""Return VarianceCalculator.getAverageVariance( proteins )."""
	calculator = f.analysis.VarianceCalculator
	return calculator.getAverageVariance( proteins )


def getAverageRmsd( proteins ):
	"""Return VarianceCalculator.getAverageRmsd( proteins )."""
	calculator = f.analysis.VarianceCalculator
	return calculator.getAverageRmsd( proteins )

	
def printRestraintsStats( restraints ):
	"""Print the total, unambiguous and ambiguous counts of `restraints` to stdout.

	Each restraint must provide isAmbiguous(). `restraints` must support len()
	and iteration.
	"""
	# materialise once so counting and the total agree
	ambiguousFlags = [ restraint.isAmbiguous() for restraint in restraints ]
	numAmbiguous = sum( 1 for flag in ambiguousFlags if flag )
	numUnambiguous = len( ambiguousFlags ) - numAmbiguous
	
	# single-argument print(...) behaves identically under Python 2 and 3
	print( "Restraints stats:" )
	print( "\ttotal: %i" % len( restraints ) )
	print( "\tunambiguous: %i" % numUnambiguous )
	# colon added so this line matches the format of the other stat lines
	print( "\tambiguous: %i" % numAmbiguous )
	

def alignOptimally( reference, compared ):
	"""Center `reference` in place, then align `compared` onto it with StructureAligner.alignOptimally."""
	geometry = f.protein.tools.ProteinGeometry
	aligner = f.analysis.StructureAligner
	geometry.center( reference )
	aligner.alignOptimally( reference, compared )
	
	
def alignOptimallyBySubunit( reference, compared, subunitId ):
	"""Translate `reference` so subunit `subunitId`'s centroid is at the origin, then align `compared` on that subunit."""
	geometry = f.protein.tools.ProteinGeometry
	# translating by the negated centroid moves the centroid to the origin
	shift = geometry.getCentroid( reference.getSubunit( subunitId ) )
	shift.negate()
	geometry.translate( reference, shift )
	f.analysis.StructureAligner.alignOptimallyBySubunit( reference, compared, subunitId )


def alignOptimallyByAtoms( reference, compared, referenceAddresses, comparedAddresses ):
	"""Center `reference` on `referenceAddresses`, then align `compared` using the paired atom addresses."""
	geometry = f.protein.tools.ProteinGeometry
	aligner = f.analysis.StructureAligner
	geometry.center( reference, referenceAddresses )
	aligner.alignOptimallyByAtoms( reference, compared, referenceAddresses, comparedAddresses )


def alignEnsembleOptimally( ensemble ):
	"""Align the members of `ensemble` with StructureAligner.alignEnsembleOptimally."""
	aligner = f.analysis.StructureAligner
	aligner.alignEnsembleOptimally( ensemble )


def alignEnsembleOptimallyByAtoms( ensemble, addresses ):
	"""Align `ensemble` on the atoms in `addresses` with StructureAligner.alignEnsembleOptimallyByAtoms."""
	aligner = f.analysis.StructureAligner
	aligner.alignEnsembleOptimallyByAtoms( ensemble, addresses )


def alignEnsembleOptimallyToReference( reference, ensemble ):
	"""Center `reference` once, then align every structure in `ensemble` onto it."""
	f.protein.tools.ProteinGeometry.center( reference )
	aligner = f.analysis.StructureAligner
	for member in ensemble:
		aligner.alignOptimally( reference, member )


def alignEnsembleOptimallyToReferenceByAtoms( reference, ensemble, addresses ):
	"""Align every structure in `ensemble` onto `reference` using only the atoms in `addresses`.

	`reference` is centered in place on the same atoms used for alignment. This
	matches alignOptimallyByAtoms (centers on referenceAddresses) and the
	BySubunit variants (center on the aligned subunit). Previously it was
	centered on all atoms while aligning on a subset.
	"""
	f.protein.tools.ProteinGeometry.center( reference, addresses )
	aligner = f.analysis.StructureAligner
	for structure in ensemble:
		aligner.alignOptimallyByAtoms( reference, structure, addresses, addresses )


def alignEnsembleOptimallyToReferenceBySubunit( reference, ensemble, subunitId ):
	"""Move `reference` so subunit `subunitId`'s centroid is at the origin, then align each ensemble member on that subunit."""
	geometry = f.protein.tools.ProteinGeometry
	shift = geometry.getCentroid( reference.getSubunit( subunitId ) )
	shift.negate()
	geometry.translate( reference, shift )
	
	aligner = f.analysis.StructureAligner
	for member in ensemble:
		aligner.alignOptimallyBySubunit( reference, member, subunitId )


def saveEnsembleKinemage( fileOut, ensemble, reference=None, distanceRestraints=None, varySubunitColors=False ):
	"""Write a kinemage titled "Ensemble" showing the backbone of each structure in `ensemble`.

	fileOut            -- path passed to KinemageWriter.write
	ensemble           -- iterable of structures; None entries are skipped but still
	                      consume a model number, so labels match ensemble positions
	reference          -- optional structure drawn as "Reference"
	distanceRestraints -- optional; drawn against the first non-null structure
	varySubunitColors  -- accepted for compatibility but unused: each subunit always
	                      gets its own KinemageColor

	Each model is labelled "Ensemble N" with N counting from 1.
	"""
	kinemage = f.kinemage.Kinemage( "Ensemble" )
	KinemageBuilder = f.kinemage.KinemageBuilder
	KinemageBuilder.appendAxes( kinemage )
	
	# add the reference structure
	if reference is not None:
		KinemageBuilder.appendBackbone( kinemage, reference, "Reference", 1, 3 )
	
	# add the ensemble structures; the model number and subunit index are kept
	# separate (they previously shared `i`, which mislabelled every model)
	allColors = f.kinemage.KinemageColor.values()
	for modelNum, structure in enumerate( ensemble, 1 ):
		if structure is None:
			continue
		
		# one color per subunit, wrapping around if there are more subunits than colors
		colors = jvm.toArrayList( [] )
		for subunitIndex in range( structure.getSubunits().size() ):
			colors.add( allColors[subunitIndex % len( allColors )] )
		
		KinemageBuilder.appendBackbone( kinemage, structure, "Ensemble %d" % modelNum, colors, 2 )
	
	# add the distance restraints if needed (using the first non-null structure)
	if distanceRestraints is not None:
		for structure in ensemble:
			if structure is not None:
				KinemageBuilder.appendDistanceRestraints( kinemage, structure, distanceRestraints )
				break
	
	# write it!
	f.kinemage.KinemageWriter().write( kinemage, fileOut )
	

def savePairKinemage( pathOut, reference, computed ):
	
	# build the kinemage
	kinemage = f.kinemage.Kinemage( "Pair" )
	KinemageBuilder = f.kinemage.KinemageBuilder
	KinemageBuilder.appendAxes( kinemage )
	KinemageBuilder.appendBackbone( kinemage, reference, "Reference", 1, 3 )
	KinemageBuilder.appendBackbone( kinemage, computed, "Computed", 0, 2 )
	
	# write it!
	f.kinemage.KinemageWriter().write( kinemage, pathOut )
	print "Kinemage with 2 structures written to:\n\t%s" % pathOut


def saveStructureKinemage( pathOut, protein, restraints=None ):
	
	# build the kinemage
	kinemage = f.kinemage.Kinemage( "Pair" )
	KinemageBuilder = f.kinemage.KinemageBuilder
	KinemageBuilder.appendAxes( kinemage )
	KinemageBuilder.appendProtein( kinemage, protein )
	
	if restraints != None:
		KinemageBuilder.appendDistanceRestraints( kinemage, protein, restraints )
	
	# write it!
	f.kinemage.KinemageWriter().write( kinemage, pathOut )
	print "Kinemage with 1 structure written to:\n\t%s" % pathOut


def runKing( pathKinemage ):
	"""Open `pathKinemage` in the KiNG viewer and block until it exits.

	The command is passed as an argument list with no shell involved. The old
	os.system string broke on (and allowed shell injection through) paths
	containing quotes, `$` or backticks.
	"""
	import subprocess
	subprocess.call( [ "king", pathKinemage ] )
	

def minimize( protein, distanceRestraints=None, logPath=None ):
	"""Minimize `protein` with xplor.StructureMinimizer and return its result.

	distanceRestraints -- optional; passed to setDistanceRestraints when given
	logPath            -- optional; passed to setLog when given
	"""
	minimizer = f.xplor.StructureMinimizer()
	# identity checks: `!=` on JPype proxies may dispatch to Java equality
	if distanceRestraints is not None:
		minimizer.setDistanceRestraints( distanceRestraints )
	if logPath is not None:
		minimizer.setLog( logPath )
	return minimizer.minimize( protein )


def calculateEnergy( protein, logPath=None ):
	"""Return the energy of `protein` from xplor.EnergyCalculator.

	logPath -- optional; passed to setLog when given
	"""
	calculator = f.xplor.EnergyCalculator()
	if logPath is not None:
		calculator.setLog( logPath )
	return calculator.getEnergy( protein )


def newSubunitOrder( subunitOrder ):
	"""Construct and return a libprotnmr protein.SubunitOrder from `subunitOrder`."""
	SubunitOrder = f.protein.SubunitOrder
	return SubunitOrder( subunitOrder )


def newAlignmentTensor( protein, rdcs, numSamples=0 ):
	
	tensor = f.nmr.AlignmentTensor.compute( protein, rdcs )
	print tensor.getStats( protein, rdcs )
	return tensor


def newAlignmentTensors( protein, rdcs, numSamples ):
	"""Return AlignmentTensor.compute( protein, rdcs, numSamples ), the sampled overload."""
	AlignmentTensor = f.nmr.AlignmentTensor
	return AlignmentTensor.compute( protein, rdcs, numSamples )

