#-------------------------------------------------------------------------------
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
# USA
# 
# Contact Info:
# 	Bruce Donald
# 	Duke University
# 	Department of Computer Science
# 	Levine Science Research Center (LSRC)
# 	Durham
# 	NC 27708-0129 
# 	USA
# 	brd@cs.duke.edu
# 
# Copyright (C) 2011 Jeffrey W. Martin and Bruce R. Donald
# 
# <signature of Bruce Donald>, April 2011
# Bruce Donald, Professor of Computer Science
#-------------------------------------------------------------------------------


###########################################
##  More Settings
###########################################

# when True, writeChart() writes every chart as a PNG image
WriteChartsInPng = True
# when True, writeChart() also writes every chart through ChartWriter.writeSvg
WriteChartsInSvg = False
# maximum heap size of the embedded JVM, passed as "-Xmx<value>" when it is started
JvmMaxHeapSize = "1g"

###########################################


import math, subprocess, threading, os
from collections import deque
import jvm, jdshot, share


# start the jvm
# The classpath entry is added and share/jdshot are handed the instance before
# start() launches the JVM with the heap limit from JvmMaxHeapSize. This
# module-level instance is also used by msrsWorker (attachThread, getClasspath).
jvmInstance = jvm.Jvm()
jvmInstance.addPath( "lib/disco.jar" )
# NOTE(review): the commented-out entries below look like alternatives for
# running from a source checkout instead of the packaged jar -- confirm.
#jvmInstance.addPath( "bin" )
#jvmInstance.addPath( "../share/bin" )
#jvmInstance.addLibsDir( "lib" )
#jvmInstance.addLibsDir( "../share/lib" )
share.initJvm( jvmInstance )
jdshot.initJvm( jvmInstance )
jvmInstance.start( "-Xmx%s" % JvmMaxHeapSize )


# java class shortcuts
# Module-level aliases for Java classes reached through the share and jdshot
# packages; they are resolved at import time, after the JVM was started above.
Subunit = share.f.protein.Subunit
Protein = share.f.protein.Protein
PseudoatomBuilder = share.f.pseudoatoms.PseudoatomBuilder
AlignmentTensorAxis = share.f.nmr.AlignmentTensorAxis
MonomerCloner = share.f.protein.tools.MonomerCloner
Element = share.f.protein.Element
ChemicalShiftMapper = share.f.nmr.ChemicalShiftMapper
DistanceRestraintReassigner = share.f.nmr.DistanceRestraintReassigner

SymmetryContext = jdshot.f.context.SymmetryContext
DistanceRestraintsContext = jdshot.f.context.DistanceRestraintsContext
RdcsContext = jdshot.f.context.RdcsContext
CnGridPoint = jdshot.f.grid.cn.CnGridPoint

Plotter = jdshot.f.disco.Plotter
OrientationCalculator = jdshot.f.disco.OrientationCalculator
AnnulusCalculator = jdshot.f.disco.AnnulusCalculator
PositionCalculator = jdshot.f.disco.PositionCalculator
ViolationCalculator = jdshot.f.disco.ViolationCalculator
StructureEvaluator = jdshot.f.disco.StructureEvaluator
ResultsWriter = jdshot.f.disco.ResultsWriter
ResultsReader = jdshot.f.disco.ResultsReader
ChartWriter = jdshot.f.chart.ChartWriter
SubunitVariationAnalyzer = jdshot.f.disco.SubunitVariationAnalyzer


def run( outDir, pathInEnsemble, numSubunits, pathInRdcs, pathInTrustedDistanceRestraints, orientationSamplingResolution, positionSamplingResolution,
	modelId=0, alignmentTensorAxis=None, numAlignmentTensorSamples=10000, orientationSamplingPercentile=1.0, pathInUntrustedDistanceRestraints=None,
	numMinimizationProcesses=1, numMsrsProcesses=1, energySevereClashThreshold=10000.0, trustedPaddingPercent=None, untrustedPaddingPercent=None ):
	
	# check the alignment tensor axis
	if alignmentTensorAxis is None:
		if numSubunits > 2:
			alignmentTensorAxis = AlignmentTensorAxis.Z
		else:
			raise TypeError( "The alignment tensor axis must be specified for dimers. Run computeDimerAlignmentTensorAxis() to find the best axis." )
	
	# are we using untrusted restraints?
	usingUntrusted = pathInUntrustedDistanceRestraints is not None
	
	# set up input files
	fileEnsemble = jvm.newFile( pathInEnsemble )
	fileRdcs = jvm.newFile( pathInRdcs )
	fileTrustedDistanceRestraints = jvm.newFile( pathInTrustedDistanceRestraints )
	fileUntrustedDistanceRestraints = None
	if usingUntrusted:
		fileUntrustedDistanceRestraints = jvm.newFile( pathInUntrustedDistanceRestraints )
	
	# make the output folder if needed
	jvm.newFile( outDir ).mkdirs()
	
	# set up output files
	fileOrientations = jvm.newFile( outDir, "sampledOrientations.dat" )
	fileVariationTemplate = jvm.newFile( outDir, "variation.%s.dat" )
	fileAnnulusIndicesTemplate = jvm.newFile( outDir, "annulusIndices.%s.dat" )
	fileMsrsTemplate = jvm.newFile( outDir, "msrs.%s.%d.dat" )
	fileArrangementTemplate = jvm.newFile( outDir, "arrangement.%s.%d.png" )
	fileArrangementZoomedTemplate = jvm.newFile( outDir, "arrangement.zoomed.%s.%d.png" )
	fileInconsistentDistanceRestraintsTemplate = jvm.newFile( outDir, "inconsistentDistanceRestraints.%s.%d.png" )
	fileInconsistentDistanceRestraintsZoomedTemplate = jvm.newFile( outDir, "inconsistentDistanceRestraints.zoomed.%s.%d.png" )
	filePositionSamplingTemplate = jvm.newFile( outDir, "positionSampling.%d.png" )
	fileComputedStructures = jvm.newFile( outDir, "computedStructures.pdb" )
	fileMinimizedScores = jvm.newFile( outDir, "minimizedScores.dat" )
	fileMinimizedStructures = jvm.newFile( outDir, "minimizedStructures.pdb" )
	fileComputedStructureScores = jvm.newFile( outDir, "computeStructureScores.tsv" )
	fileMinimizedStructureScores = jvm.newFile( outDir, "minimizedStructureScores.tsv" )
	fileComputedStructureScoresPlot = jvm.newFile( outDir, "minimizedStructureScores.png" )
	fileMinimizedStructureScoresPlot = jvm.newFile( outDir, "minimizedStructureScores.png" )
	
	# load the inputs
	ensemble = loadSubunitEnsemble( fileEnsemble )
	subunit = ensemble.get( modelId )
	symmetryContext = SymmetryContext( jdshot.Cn, subunit, numSubunits )
	trustedDistanceRestraintsContext = DistanceRestraintsContext( symmetryContext, fileTrustedDistanceRestraints )
	rdcsContext = RdcsContext( symmetryContext, fileRdcs )
	
	# load the untrusted distance restraints if needed
	if usingUntrusted:
		untrustedDistanceRestraintsContext = DistanceRestraintsContext( symmetryContext, fileUntrustedDistanceRestraints )
	
	# sample orientations if needed
	if orientationSamplingResolution > 0:
		if fileOrientations.exists():
			print "Orientations already computed. Reading from file."
			orientations = ResultsReader().readOrientations( fileOrientations )
		else:
			orientations = getSampledOrientations(
				outDir,
				symmetryContext,
				rdcsContext,
				alignmentTensorAxis,
				orientationSamplingPercentile,
				orientationSamplingResolution,
				numAlignmentTensorSamples
			)
			ResultsWriter().writeOrientations( fileOrientations, orientations )
	else:
		bestTensor = share.newAlignmentTensor( symmetryContext.getMonomer(), rdcsContext.getInternalRdcs() )
		orientations = jvm.toArrayList( [ bestTensor.getAxis( alignmentTensorAxis ) ] )
	
	# apply the padding or variation
	applyPadding( trustedFile( fileVariationTemplate ), trustedDistanceRestraintsContext, trustedPaddingPercent, ensemble, subunit, numSubunits )
	if usingUntrusted:
		applyPadding( untrustedFile( fileVariationTemplate ), untrustedDistanceRestraintsContext, untrustedPaddingPercent, ensemble, subunit, numSubunits )
	
	# find out which trusted MSRs need to be computed
	# also compute the trusted annulus indices
	indicesToCompute = deque()
	trustedAnnulusIndices = jvm.toArrayList( [] )
	for i in range( 0, orientations.size() ):
		
		# have the trusted MSRs already been computed?
		if not ithFile( i, trustedFile( fileMsrsTemplate ) ).exists():
			indicesToCompute.append( i )
		
		# build the trusted annulus index
		trustedAnnulusIndices.add( AnnulusCalculator.computeAnnuli(
			getNormalizedSubunit( subunit, orientations.get( i ) ),
			trustedDistanceRestraintsContext.getInternalRestraints(),
			numSubunits
		) )
	
	ResultsWriter().writeAnnulusIndices( trustedFile( fileAnnulusIndicesTemplate ), trustedAnnulusIndices )
	
	# start the worker threads to compute all the MSRs
	print "Need to compute %d more sets of trusted MSRs..." % len( indicesToCompute )
	if len( indicesToCompute ) > 0 :
		computeAllMsrs(
			trustedFile( fileMsrsTemplate ),
			trustedFile( fileAnnulusIndicesTemplate ),
			indicesToCompute,
			numMsrsProcesses
		)
		
	# read in the trusted msrs
	print "Reading in all the trusted MSRs..."
	trustedMsrsList = []
	for i in range( 0, orientations.size() ):
		trustedMsrsList.append( ResultsReader().readMsrs( ithFile( i, trustedFile( fileMsrsTemplate ) ) ) )
	
	# plot all the arrangements and MSRs
	for i in range( 0, orientations.size() ):
		print "Analyzing trusted MSRs for orientation %s of %d..." % ( i + 1, orientations.size() )
		
		trustedMsrs = trustedMsrsList[i]
		trustedAnnulusIndex = trustedAnnulusIndices.get( i )
		
		# plot the arrangements
		writeChart( ithFile( i, trustedFile( fileArrangementTemplate ) ), Plotter.plotArrangement( trustedMsrs, trustedAnnulusIndex ) )
		writeChart( ithFile( i, trustedFile( fileArrangementZoomedTemplate ) ), Plotter.plotArrangement( trustedMsrs, trustedAnnulusIndex, True ) )
		
		reportConsistency(
			ithFile( i, trustedFile( fileInconsistentDistanceRestraintsTemplate ) ),
			ithFile( i, trustedFile( fileInconsistentDistanceRestraintsZoomedTemplate ) ),
			i,
			trustedMsrs,
			trustedAnnulusIndex,
			trustedDistanceRestraintsContext.getInternalRestraints(),
			symmetryContext,
			getNormalizedSubunit( subunit, orientations.get( i ) ),
		)
	
	# do we need to handle untrusted distance restraints?
	if usingUntrusted:
		
		# read the trusted MSRs and compute the untrusted annuli
		indicesToCompute = deque()
		untrustedAnnulusIndices = jvm.toArrayList( [] )
		for i in range( 0, orientations.size() ):
			
			# have the final MSRs already been computed?
			if not ithFile( i, finalFile( fileMsrsTemplate ) ).exists():
				indicesToCompute.append( i )
			
			# read the trusted MSRs and filter the untrusted restraints
			untrustedAnnulusIndices.add( AnnulusCalculator.computeUntrustedAnnuli(
				getNormalizedSubunit( subunit, orientations.get( i ) ),
				untrustedDistanceRestraintsContext.getInternalRestraints(),
				numSubunits,
				trustedMsrsList[i]
			) )
		
		ResultsWriter().writeAnnulusIndices( untrustedFile( fileAnnulusIndicesTemplate ), untrustedAnnulusIndices )
		
		# start the worker threads to compute all the final MSRs
		print "Need to compute %d more sets of final MSRs..." % len( indicesToCompute )
		if len( indicesToCompute ) > 0 :
			computeAllMsrs(
				finalFile( fileMsrsTemplate ),
				trustedFile( fileAnnulusIndicesTemplate ),
				indicesToCompute,
				numMsrsProcesses,
				untrustedFile( fileAnnulusIndicesTemplate )
			)
			
		# read in the final msrs
		print "Reading in all the final MSRs..."
		finalMsrsList = []
		for i in range( 0, orientations.size() ):
			finalMsrsList.append( ResultsReader().readMsrs( ithFile( i, finalFile( fileMsrsTemplate ) ) ) )
		
		# plot all the arrangements and MSRs
		for i in range( 0, orientations.size() ):
			print "Analyzing final MSRs for orientation %d of %d..." % ( i + 1, orientations.size() )
			
			finalMsrs = finalMsrsList[i]
			trustedAnnulusIndex = trustedAnnulusIndices.get( i )
			untrustedAnnulusIndex = untrustedAnnulusIndices.get( i )
			trustedRestraints = trustedDistanceRestraintsContext.getInternalRestraints()
			untrustedRestraints = untrustedDistanceRestraintsContext.getInternalRestraints()
			
			# plot the arrangements and MSRs
			writeChart( ithFile( i, finalFile( fileArrangementTemplate ) ), Plotter.plotArrangement( finalMsrs, trustedAnnulusIndex, untrustedAnnulusIndex ) )
			writeChart( ithFile( i, finalFile( fileArrangementZoomedTemplate ) ), Plotter.plotArrangement( finalMsrs, trustedAnnulusIndex, untrustedAnnulusIndex, True ) )
			
			print "Consistency of Trusted distance restraints:"
			reportConsistency(
				ithFile( i, finalTrustedFile( fileInconsistentDistanceRestraintsTemplate ) ),
				ithFile( i, finalTrustedFile( fileInconsistentDistanceRestraintsZoomedTemplate ) ),
				i,
				finalMsrs,
				trustedAnnulusIndex,
				trustedRestraints,
				symmetryContext,
				getNormalizedSubunit( subunit, orientations.get( i ) ),
			)
			
			print "Consistency of Untrusted distance restraints:"
			reportConsistency(
				ithFile( i, finalUntrustedFile( fileInconsistentDistanceRestraintsTemplate ) ),
				ithFile( i, finalUntrustedFile( fileInconsistentDistanceRestraintsZoomedTemplate ) ),
				i,
				finalMsrs,
				untrustedAnnulusIndex,
				untrustedRestraints,
				symmetryContext,
				getNormalizedSubunit( subunit, orientations.get( i ) ),
			)
	
	# which set of MSRs should we use to compute oligomer structures?
	if usingUntrusted:
		msrsList = finalMsrsList
		print "Computing oligomer structures from the final MSRs..."
	else:
		msrsList = trustedMsrsList
		print "Computing oligomer structures from the trusted MSRs..."
	
	allComputedStructures = jvm.toArrayList( [] )
	numSampledStructures = []
	
	# generate oligomer structures from the trusted/final MSRs
	for i in range( 0, orientations.size() ):
		
		# compute the discrete oligomer structures
		computedStructures = computeDiscreteStructures(
			ithFile( i, filePositionSamplingTemplate ),
			symmetryContext,
			getNormalizedSubunit( subunit, orientations.get( i ) ),
			msrsList[i],
			positionSamplingResolution
		)
		allComputedStructures.addAll( computedStructures )
		numSampledStructures.append( computedStructures.size() )
	
	print "Sampled %d total structures:" % allComputedStructures.size()
	for i in range( 0, orientations.size() ):
		print "\t%d structures from orientation %d" % ( numSampledStructures[i], i + 1 )
	
	# write out the computed structures in a PDB file
	print "Writing oligomeric ensemble to PDB file..."
	share.writeProteins( fileComputedStructures, allComputedStructures )
	
	# minimize the structures if xplor is available
	if os.name == 'nt':
		print ""
		print ""
		print "################################################################"
		print "  DISCO is complete!"
		print "  The final UNminimized oligomeric ensemble has been written to:"
		print "  %s" % fileComputedStructures.getAbsolutePath()
		print "  Structure minimization is currently unsupported in Windows"
		print "     Xplor-NIH is not available in Windows."
		print "     The structures computed by DISCO have not been minimized."
		print "     You may wish to minimize them using an external program."
		print "################################################################"
		return
	
	# use the original un-padded trusted restraints only
	unpaddedDistanceRestraintsContext = DistanceRestraintsContext( symmetryContext, fileTrustedDistanceRestraints )
	
	# minimize the structures and score them
	scores = StructureEvaluator.evaluate(
		symmetryContext,
		allComputedStructures,
		unpaddedDistanceRestraintsContext.getInternalRestraints(),
		numMinimizationProcesses
	)
	if scores.hasError():
		raise Exception( "Structure minimization failed. Aborting DISCO!" )
	ResultsWriter().writeScores( fileMinimizedScores, scores )
	scores.writeMinimizedStrcutures( fileMinimizedStructures )
	print "Wrote %d minimized oligomeric structures to:\n\t%s" % ( scores.getNumScores(), fileMinimizedStructures.getAbsolutePath() )
	
	# write out minimization results and structure scores
	scores.writeBeforeToFile( fileComputedStructureScores )
	scores.writeAfterToFile( fileMinimizedStructureScores )
	writeChart(
		fileComputedStructureScoresPlot,
		Plotter.plotStructureScores( scores.getBeforeSamples(), energySevereClashThreshold )
	)
	writeChart(
		fileMinimizedStructureScoresPlot,
		Plotter.plotStructureScores( scores.getAfterSamples(), energySevereClashThreshold )
	)
	
	# evaluate the atom position deviation
	minimizedStructures = share.loadProteins( fileMinimizedStructures )
	print "All-atom average variance %f A^2 or %f A RMSD" % (
		share.getAverageVariance( minimizedStructures ),
		share.getAverageRmsd( minimizedStructures )
	)
	backbones = jvm.toArrayList( [protein.getBackbone() for protein in minimizedStructures] )
	print "Backbone atom average variance %f A^2 or %f A RMSD" % (
		share.getAverageVariance( backbones ),
		share.getAverageRmsd( backbones )
	)
	
	print ""
	print ""
	print "################################################################"
	print "  DISCO is complete!"
	print "  The final minimized oligomeric ensemble has been written to:"
	print "  %s" % fileMinimizedStructures.getAbsolutePath()
	print "################################################################"


def computeDimerAlignmentTensorAxis( outDir, pathInEnsemble, numSubunits, pathInRdcs, pathInDistanceRestraints,
	modelId=0, numMsrsProcesses=1 ):
	
	# set up input files
	fileEnsemble = jvm.newFile( pathInEnsemble )
	fileRdcs = jvm.newFile( pathInRdcs )
	fileDistanceRestraints = jvm.newFile( pathInDistanceRestraints )
	
	# make the output folder if needed
	jvm.newFile( outDir ).mkdirs()
		
	# set up output files
	fileAnnulusIndices = jvm.newFile( outDir, "annulusIndices.dat" )
	fileMsrsTemplate = jvm.newFile( outDir, "msrs.%s.dat" )
	
	# load the inputs
	ensemble = loadSubunitEnsemble( fileEnsemble )
	subunit = ensemble.get( modelId )
	symmetryContext = SymmetryContext( jdshot.Cn, subunit, numSubunits )
	distanceRestraintsContext = DistanceRestraintsContext( symmetryContext, fileDistanceRestraints )
	rdcsContext = RdcsContext( symmetryContext, fileRdcs )
	
	# compute the alignment tensor
	tensor = share.newAlignmentTensor( subunit, rdcsContext.getInternalRdcs() )
	
	# compute the annulus indices for each orientation
	annulusIndices = jvm.toArrayList( [] )
	for axis in AlignmentTensorAxis.values():
		
		# normalize a copy of the subunit
		normalizedSubunit = Subunit( subunit )
		OrientationCalculator.normalizeProtein( normalizedSubunit, tensor.getAxis( axis ) )
		
		# compute the unions of annuli
		annulusIndices.add( AnnulusCalculator.computeAnnuli(
			normalizedSubunit,
			distanceRestraintsContext.getInternalRestraints(),
			numSubunits
		) )
	ResultsWriter().writeAnnulusIndices( fileAnnulusIndices, annulusIndices )
	
	# start the worker threads to compute all the MSRs
	computeAllMsrs(
		fileMsrsTemplate,
		fileAnnulusIndices,
		deque( range( annulusIndices.size() ) ),
		numMsrsProcesses
	)
	
	# read the computed MSRs and pick the one with the best distance restraint satisfaction
	satisfaction = dict()
	for i in range( 0, annulusIndices.size() ):
		tag = "%d" % ( i + 1 )
		print "Analyzing MSRs for axis %s of %d..." % ( tag, annulusIndices.size() )
		
		# compute the satisfaction from the MSRs
		fileMsrs = jvm.newFile( fileMsrsTemplate.getAbsolutePath() % tag )
		msrs = ResultsReader().readMsrs( fileMsrs )
		satisfaction[AlignmentTensorAxis.values()[i]] = PositionCalculator.getNumSatisfiedRestraints( msrs )
	
	# print out the satisfaction stats
	print ""
	print ""
	print "################################################################"
	print "  DISCO is complete!"
	maxNumSatisfiedRestraints = 0
	bestAxis = None
	for axis, numSatisfiedRestraints in satisfaction.iteritems():
		print "  %s Axis: %d satisfied distance restraints" % ( axis.name(), numSatisfiedRestraints )
		if numSatisfiedRestraints > maxNumSatisfiedRestraints:
			maxNumSatisfiedRestraints = numSatisfiedRestraints
			bestAxis = axis
	print "  Most satisfying axis: %s" % bestAxis.name()
	print "################################################################"


def reassignNoes( pathOutReassignedNoes, pathInSubunit, pathInNoes, pathInShifts, numSubunits, hydrogenWindowSize, carbonWindowSize, nitrogenWindowSize ):

	# set up output files
	fileOutReassignedNoes = jvm.newFile( pathOutReassignedNoes )
	
	# make the output folder if needed
	fileOutReassignedNoes.getParentFile().mkdirs()
	
	# load the protein and noe data
	subunit = share.loadProtein( pathInSubunit ).getSubunits().get( 0 )
	clonedSubunit = share.cloneMonomer( subunit, numSubunits )
	noes = share.loadDistanceRestraints( pathInNoes )
	share.interpretDistanceRestraintsWithPseudoatoms( noes, clonedSubunit )
	restraints = share.mapDistanceRestraintsToProtein( noes, clonedSubunit )
	share.printRestraintsStats( restraints )
	share.reportMismatchedRestraints( restraints, numSubunits )
	restraints = share.filterRestraintsIntersubunit( restraints )
	
	# load and process the chemical shifts
	shifts = share.loadShifts( pathInShifts )
	hydrogenShifts = ChemicalShiftMapper.filter( shifts, Element.Hydrogen )
	carbonShifts = ChemicalShiftMapper.filter( shifts, Element.Carbon )
	nitrogenShifts = ChemicalShiftMapper.filter( shifts, Element.Nitrogen )
	share.interpretShiftsWithPseudoatoms( hydrogenShifts, clonedSubunit )
	mappedHydrogenShifts = share.mapShifts( hydrogenShifts, clonedSubunit )
	mappedCarbonShifts = share.mapShifts( carbonShifts, clonedSubunit )
	mappedNitrogenShifts = share.mapShifts( nitrogenShifts, clonedSubunit )
	
	print "Num Hydrogen Shifts: ", mappedHydrogenShifts.size()
	print "Num Carbon Shifts: ", mappedCarbonShifts.size()
	print "Num Nitrogen Shifts: ", mappedNitrogenShifts.size()
	
	# combine the heavy shifts
	mappedHeavyShifts = jvm.toArrayList( [] )
	mappedHeavyShifts.addAll( mappedCarbonShifts )
	mappedHeavyShifts.addAll( mappedNitrogenShifts )
	
	carbonPairs = ChemicalShiftMapper.associatePairs( clonedSubunit, mappedHydrogenShifts, mappedCarbonShifts, Element.Carbon )
	nitrogenPairs = ChemicalShiftMapper.associatePairs( clonedSubunit, mappedHydrogenShifts, mappedNitrogenShifts, Element.Nitrogen )
	orphanedHydrogenShifts = ChemicalShiftMapper.getOrphanedHydrogenShifts( clonedSubunit, mappedHydrogenShifts, mappedHeavyShifts )
	
	print "Num Carbon Pairs", carbonPairs.size()
	print "Num Nitrogen Pairs", nitrogenPairs.size()
	print "Num Orphaned Hydrogen Shifts", orphanedHydrogenShifts.size()
	
	# reassign the NOEs
	print "Reassigning NOEs..."
	print "\tAssignments before: total=%i, avg=%f" % ( totalNumberOfAssignments( restraints ), avgNumberOfAssignments( restraints ) )
	#reassignedRestraints = DistanceRestraintReassigner.reassign1D( clonedSubunit, restraints, mappedHydrogenShifts, hydrogenWindowSize )
	reassignedRestraints = DistanceRestraintReassigner.reassignDouble3D(
		clonedSubunit,
		restraints,
		mappedHydrogenShifts,
		carbonPairs,
		nitrogenPairs,
		hydrogenWindowSize,
		carbonWindowSize,
		nitrogenWindowSize
	)
	print "\tAssignments after (expanded pseudoatoms): total=%i, avg=%f" % ( totalNumberOfAssignments( reassignedRestraints ), avgNumberOfAssignments( reassignedRestraints ) )
	
	# unmap and save
	reassignedNoes = share.unmapDistanceRestraintsFromProtein( reassignedRestraints, clonedSubunit, True )
	share.writeDistanceRestraints( fileOutReassignedNoes, reassignedNoes )
	
	print "\tAssignments after (collapsed pseudoatoms): total=%i, avg=%f" % ( totalNumberOfAssignments( reassignedNoes ), avgNumberOfAssignments( reassignedNoes ) )
	
	print ""
	print ""
	print "################################################################"
	print "  DISCO is complete!"
	print "  The final reassigned NOEs have been written to:"
	print "  %s" % fileOutReassignedNoes.getAbsolutePath()
	print "################################################################"

	
	
def computeAllMsrs( fileMsrsTemplate, fileAnnulusIndices, indicesToCompute, numMsrsProcesses, fileUntrustedAnnulusIndices = None ):
	
	indicesLock = threading.Lock()
	threadList = []
	for i in range( 0, numMsrsProcesses ):
		threadArgs = (
			i + 1,
			indicesLock,
			indicesToCompute,
			fileMsrsTemplate,
			fileAnnulusIndices, 
			len( indicesToCompute ),
			fileUntrustedAnnulusIndices
		)
		thread = threading.Thread( target = msrsWorker, args = threadArgs )
		threadList.append( thread )
	print "Starting %d workers to compute MSRs..." % len( threadList )
	for thread in threadList:
		thread.start()
	for thread in threadList:
		thread.join()
	print "All MSRs computed!"


def msrsWorker( workerNumber, indicesLock, indicesToCompute, fileMsrsTemplate, fileAnnulusIndices, numOrientations, fileUntrustedAnnulusIndices = None ):
	"""Worker thread body: pop orientation indices and compute their MSRs one by one.
	
	Each index is taken from the shared indicesToCompute deque under indicesLock.
	Its MSRs are computed by a separate `java ... DiscoMsrMain` process whose
	stdout and stderr are echoed through pipeWorker. The worker returns once the
	deque is empty. A non-zero exit code of the child process is reported but
	does not stop the worker.
	"""
	
	# this thread calls into the jvm (getClasspath), so attach it first
	jvmInstance.attachThread()
	
	while True:
		
		# claim the next orientation, or stop when the queue is drained
		with indicesLock:
			if not indicesToCompute:
				return
			orientationId = indicesToCompute.popleft()
		
		print( "Worker %s: Computing MSRs for orientation %d of %d..." % ( workerNumber, orientationId + 1, numOrientations ) )
		
		# a separate process keeps the native MSR code's memory leaks out of this process
		command = [
			"java", "-cp", jvmInstance.getClasspath(), "-Xmx256m",
			"edu.duke.donaldLab.jdshot.DiscoMsrMain",
			ithFile( orientationId, fileMsrsTemplate ).getAbsolutePath(),
			fileAnnulusIndices.getAbsolutePath(),
			"%d" % orientationId,
		]
		if fileUntrustedAnnulusIndices is not None:
			command += [ "-u", fileUntrustedAnnulusIndices.getAbsolutePath() ]
		process = subprocess.Popen( command, stdout = subprocess.PIPE, stderr = subprocess.PIPE )
		
		# drain both pipes concurrently so the child never blocks on a full pipe
		monitors = [
			threading.Thread( target = pipeWorker, args = ( workerNumber, stream ) )
			for stream in ( process.stdout, process.stderr )
		]
		for monitor in monitors:
			monitor.start()
		for monitor in monitors:
			monitor.join()
		
		if process.wait() != 0:
			print( "Msrs computation worker %d ended with an error!" % workerNumber )
		

def pipeWorker( workerNumber, pipe ):
	"""Echo each line read from `pipe`, stripped and prefixed with the worker number, until EOF."""
	
	while True:
		rawLine = pipe.readline()
		# an empty read means the other end closed the pipe
		if not rawLine:
			return
		print( "Worker %d: %s" % ( workerNumber, rawLine.strip() ) )


def getSampledOrientations( outDir, symmetryContext, rdcsContext, axis, percentile, resolution, numSamples ):
	"""Sample subunit orientations in a cone around the best alignment tensor axis.
	
	The function draws numSamples alignment tensors from the RDCs and keeps the
	cluster around the best tensor. An elliptical cone is fitted to the given
	percentile of that cluster and sampled at `resolution`. Charts of the tensor
	axes, their deviation and the sampled orientations are written into outDir.
	
	Returns the orientations produced by OrientationCalculator.computeOrientationsInRange.
	"""
	
	bestTensor = share.newAlignmentTensor( symmetryContext.getMonomer(), rdcsContext.getInternalRdcs() )
	
	# estimate the distribution of alignment tensor axes
	sampledTensors = share.newAlignmentTensors( symmetryContext.getMonomer(), rdcsContext.getInternalRdcs(), numSamples )
	axesChart = Plotter.plotAlignmentTensorsAxis( sampledTensors, axis )
	writeChart( jvm.newFile( outDir, "sampledTensorAxes.png" ), axesChart )
	
	# keep only the cluster around the best tensor
	clusteredTensors = OrientationCalculator.pickTensorCluster( sampledTensors, bestTensor, axis )
	deviationChart = Plotter.plotAlignmentTensorAxisDeviation( clusteredTensors, bestTensor, axis )
	writeChart( jvm.newFile( outDir, "sampledTensorAxesDeviation.png" ), deviationChart )
	
	maxDeviationDegrees = math.degrees( OrientationCalculator.getMaxAngleDeviationRadians( clusteredTensors, bestTensor, axis ) )
	print( "Max angle between sampled tensor and best tensor axis is %0.1f deg" % maxDeviationDegrees )
	
	# fit an elliptical cone to the requested percentile and sample it uniformly
	cone = OrientationCalculator.getEllipticalConeByPercentile( clusteredTensors, bestTensor, axis, percentile )
	coneAngleDegrees = math.degrees( max( cone.xMaxAngleRadians, cone.yMaxAngleRadians ) )
	print( "Max angle between sampled tensor and best tensor axis is %.1f deg at %.2f percentile" % ( coneAngleDegrees, percentile * 100.0 ) )
	sampled = OrientationCalculator.computeOrientationsInRange( cone, resolution )
	print( "Sampled %d representative orientations" % sampled.size() )
	
	samplingChart = Plotter.plotOrientationSampling( clusteredTensors, bestTensor, axis, sampled )
	writeChart( jvm.newFile( outDir, "sampledOrientations.png" ), samplingChart )
	
	return sampled


def applyPadding( fileVariation, distanceRestraintsContext, paddingPercent, ensemble, subunit, numSubunits ):
	"""Widen the context's internal distance restraints in place, then store them back.
	
	paddingPercent None or negative: the subunit variation across `ensemble` is
	computed, written to fileVariation and applied to the restraints.
	paddingPercent positive: the restraints are padded by that percentage.
	paddingPercent zero: the restraints are left unchanged.
	In every case the restraints are re-set on the context together with a
	numSubunits-fold clone of `subunit`.
	"""
	
	restraints = distanceRestraintsContext.getInternalRestraints()
	
	useEnsembleVariation = paddingPercent is None or paddingPercent < 0
	if useEnsembleVariation:
		variation = SubunitVariationAnalyzer.getVariation( ensemble, distanceRestraintsContext )
		ResultsWriter().writeVariation( fileVariation, variation )
		SubunitVariationAnalyzer.applyVariation( variation, restraints )
	elif paddingPercent > 0:
		share.padRestraints( restraints, paddingPercent )
	
	clonedOligomer = MonomerCloner.clone( subunit, numSubunits )
	distanceRestraintsContext.setInternalRestraints( restraints, clonedOligomer )


def reportConsistency( fileInconsistentDistanceRestraints, fileInconsistentDistanceRestraintsZoomed, orientationId, msrs, annulusIndex, restraints, symmetryContext, normalizedSubunit ):
	"""Print how well one orientation's MSRs agree with a set of distance restraints.
	
	The function prints the number of consistent restraints and lets
	AnnulusCalculator report unsatisfiable ones. For inconsistent restraints it
	prints their violations. When any violations are found, a normal and a zoomed
	violation chart are written to the two given files.
	"""
	
	satisfied = PositionCalculator.getConsistentRestraints( msrs, restraints )
	print( "\tOrientation %d satisfies %d of %d distance restraints" % ( orientationId + 1, satisfied.size(), annulusIndex.getNumRestraints() ) )
	
	# the annulus index may have been built from a subset of the restraints
	if restraints.size() != annulusIndex.getNumRestraints():
		print( "\t\t(MSRs were computed using only %d of the total %d distance restraints)" % ( annulusIndex.getNumRestraints(), restraints.size() ) )
	
	AnnulusCalculator.reportUnsatisfiableRestraints( annulusIndex, symmetryContext.getPseudoMonomer(), symmetryContext.getNumSubunits() )
	
	unsatisfied = PositionCalculator.getInconsistentRestraints( msrs, restraints )
	if unsatisfied.size() <= 0:
		print( "\t\tAll distance restraints are consistent!" )
		return
	
	print( "\t\tComputed %d inconsistent distance restraints:" % unsatisfied.size() )
	violations = ViolationCalculator.getViolations( msrs, annulusIndex, unsatisfied, symmetryContext, normalizedSubunit )
	if violations.isEmpty():
		return
	
	print( "\t\t\t(Note: each distance restraint assignment without a corresponding annulus was ignored)" )
	print( "\t\t\tAnnulus distance to MSRs (in A)\tAtom position violation (in A)\tDistance restraint" )
	for entry in violations:
		print( "\t\t\t%f\t%f\t%s" % ( entry.distance, entry.atomDistance, entry.getAssignString() ) )
	
	fullChart = Plotter.plotViolations( msrs, violations, annulusIndex, unsatisfied )
	writeChart( fileInconsistentDistanceRestraints, fullChart )
	zoomedChart = Plotter.plotViolations( msrs, violations, annulusIndex, unsatisfied, True )
	writeChart( fileInconsistentDistanceRestraintsZoomed, zoomedChart )


def computeDiscreteStructures( fileSampling, symmetryContext, normalizedSubunit, msrs, sampleResolution ):
	"""Build one oligomer per symmetry axis position sampled from the MSRs.
	
	When at least one position is sampled, a chart of the sampling is written to
	fileSampling. Returns a Java list of the generated oligomers, which is empty
	when no positions were sampled.
	"""
	
	positions = PositionCalculator.sampleMsrs( msrs, sampleResolution )
	
	hasSamples = positions.size() > 0
	if hasSamples:
		writeChart( fileSampling, Plotter.plotSampling( msrs, positions ) )
	
	# each sampled (x, y) position becomes the symmetry axis of one oligomer
	return jvm.toArrayList( [
		symmetryContext.getOligomer( CnGridPoint( position.x, position.y, 0.0, 0.0 ), normalizedSubunit )
		for position in positions
	] )


def loadSubunitEnsemble( fileEnsemble ):
	"""Load every protein in fileEnsemble, build its pseudoatoms, and return a Java list of their first subunits."""
	
	models = share.loadProteins( fileEnsemble )
	firstSubunits = jvm.toArrayList( [] )
	for model in models:
		builder = PseudoatomBuilder.getInstance()
		builder.build( model )
		firstSubunits.add( model.getSubunit( 0 ) )
	return firstSubunits


def getNormalizedSubunit( subunit, orientation ):
	"""Return a copy of `subunit` normalized to `orientation`; the input subunit is not modified."""
	
	subunitCopy = Subunit( subunit )
	OrientationCalculator.normalizeProtein( subunitCopy, orientation )
	return subunitCopy


def totalNumberOfAssignments( restraints ):
	"""Return the sum of getNumAssignments() over all restraints (0 when there are none)."""
	
	# builtin sum() instead of a local accumulator that shadowed the builtin
	return sum( restraint.getNumAssignments() for restraint in restraints )

	
def avgNumberOfAssignments( restraints ):
	"""Return the mean number of assignments per restraint as a float.
	
	Returns 0.0 for an empty restraint list; previously the division raised
	ZeroDivisionError and aborted the statistics printout.
	"""
	
	numRestraints = restraints.size()
	if numRestraints == 0:
		return 0.0
	return float( totalNumberOfAssignments( restraints ) ) / float( numRestraints )


def applyFileTag( file, tags ):
	"""Return a new file with the first '%' placeholder of the file's name filled by `tags`.
	
	Any further '%' placeholders in the name survive unchanged, so a later call
	can fill the next one. For example, "msrs.%s.%d.dat" tagged with "trusted"
	becomes "msrs.trusted.%d.dat". Every '%' in the directory part of the path is
	kept literally. Previously, the first '%' anywhere in the absolute path was
	treated as the placeholder, so a '%' in a directory name broke the format.
	"""
	
	path = file.getAbsolutePath()
	
	# split the path just after the last separator: directory part | file name
	splitPos = path.rfind( os.sep )
	if os.altsep is not None:
		splitPos = max( splitPos, path.rfind( os.altsep ) )
	splitPos += 1
	dirPart = path[:splitPos]
	namePart = path[splitPos:]
	
	# escape every % in the directory, and all but the first % in the file name
	escapedDir = dirPart.replace( "%", "%%" )
	escapedName = namePart.replace( "%", "%%" ).replace( "%%", "%", 1 )
	
	return jvm.newFile( ( escapedDir + escapedName ) % tags )


def trustedFile( file ):
	"""Fill the first placeholder of `file` with "trusted"."""
	return applyFileTag( file, tags = "trusted" )


def untrustedFile( file ):
	"""Fill the first placeholder of `file` with "untrusted"."""
	return applyFileTag( file, tags = "untrusted" )


def finalFile( file ):
	"""Fill the first placeholder of `file` with "final"."""
	return applyFileTag( file, tags = "final" )


def finalTrustedFile( file ):
	"""Fill the first placeholder of `file` with "finalTrusted"."""
	return applyFileTag( file, tags = "finalTrusted" )


def finalUntrustedFile( file ):
	"""Fill the first placeholder of `file` with "finalUntrusted"."""
	return applyFileTag( file, tags = "finalUntrusted" )


def ithFile( i, file ):
	"""Fill the first placeholder of `file` with the 1-based number of 0-based index i."""
	oneBasedIndex = i + 1
	return applyFileTag( file, tags = oneBasedIndex )

	
def writeChart( file, chart ):
	"""Write `chart` to `file` in each format enabled by the module settings.
	
	PNG output is controlled by WriteChartsInPng and SVG output by
	WriteChartsInSvg. PNG is written first when both are enabled.
	"""
	# NOTE(review): both formats receive the same `file`, and the callers' file
	# names end in ".png"; if SVG output is enabled, confirm whether
	# ChartWriter.writeSvg picks its own extension or overwrites the PNG.
	outputs = (
		( WriteChartsInPng, lambda: ChartWriter.writePng( chart, file ) ),
		( WriteChartsInSvg, lambda: ChartWriter.writeSvg( chart, file ) ),
	)
	for enabled, write in outputs:
		if enabled:
			write()

