#include "spectral_pairs.h"
#include "spectral_alignment.h"
#include "batch.h"
#include "alignment_scoring.h"

/*
 * MergeIntoReference - merges the peaks of newPeaks into refSpec and writes the combined
 *   peak list to result.
 *
 * Mass bins are computed as (int)round(mass/resolution).
 *  - New peaks in a lower bin than the current reference peak (or left over after all
 *    reference peaks) are copied to result unchanged.
 *  - For each reference peak, the intensity of every new peak within +/- mergeTol of its
 *    mass is added to the reference peak's intensity, weighted linearly from 1 (same mass)
 *    down to 0.1 (mass difference == mergeTol). With mergeTol <= 0, full weight is used.
 *  - New peaks in the same bin as the reference peak are then skipped (absorbed).
 *
 * refSpec    - reference spectrum. NOTE: its peak intensities are modified in place.
 * newPeaks   - peaks to merge in. Presumably sorted by increasing mass, like refSpec
 *              (TODO confirm with callers).
 * mergeTol   - mass tolerance used for intensity merging.
 * resolution - mass resolution used to decide whether two peaks share a mass bin.
 * result     - output spectrum. Its peak list is resized to the number of merged peaks.
 */
void MergeIntoReference(Spectrum &refSpec, Spectrum &newPeaks, float mergeTol, float resolution, Spectrum &result) {
	result.resize(refSpec.size()+newPeaks.size());
	unsigned int refIdx=0, newIdx=0, resultIdx=0,
	             pivotIdx;
	float penaltyFactor;

	while(refIdx<refSpec.size() or newIdx<newPeaks.size()) {
		// Emit new peaks whose mass bin precedes the current reference peak (or all remaining ones).
		if(refIdx==refSpec.size() or (newIdx<newPeaks.size() and (int)round(newPeaks[newIdx][0]/resolution)<(int)round(refSpec[refIdx][0]/resolution)))
			{ result[resultIdx++] = newPeaks[newIdx++]; continue; }

		// Without new peaks there is nothing to merge (and newPeaks[0] must not be read).
		if(newPeaks.size()>0) {
			// Walk back from the last consumed new peak to the first one inside [ref-mergeTol, ...].
			for(pivotIdx = (newIdx>0?newIdx-1:0); pivotIdx>0 and newPeaks[pivotIdx][0]>=refSpec[refIdx][0]-mergeTol; pivotIdx--);
			if(newPeaks[pivotIdx][0]<refSpec[refIdx][0]-mergeTol) pivotIdx++;
			// Accumulate the intensity of every new peak within +/- mergeTol, weighted by mass proximity.
			for(; pivotIdx<newPeaks.size() and newPeaks[pivotIdx][0]<=refSpec[refIdx][0]+mergeTol; pivotIdx++) {
				// mergeTol<=0 only admits exact mass matches; avoid the 0/0 of the linear weight.
				penaltyFactor = (mergeTol>0) ? 1 - 0.9*(fabs(refSpec[refIdx][0]-newPeaks[pivotIdx][0])/mergeTol) : 1;
				refSpec[refIdx][1]+=newPeaks[pivotIdx][1]*penaltyFactor;
			}
		}

		// New peaks falling in the same mass bin as the reference peak are absorbed by it.
		int roundedRefMass = (int)round(refSpec[refIdx][0]/resolution);
		while(newIdx<newPeaks.size() and (int)round(newPeaks[newIdx][0]/resolution)==roundedRefMass)
			newIdx++;

		result[resultIdx++] = refSpec[refIdx++];
	}
	result.resize(resultIdx);
}

/*
 * SplitSpectra - builds a set holding two entries per input spectrum.
 *   specSetSplit[2*i]   is a copy of specSet[i];
 *   specSetSplit[2*i+1] is specSet[i] reversed via Spectrum::reverse(0, ...).
 */
void SplitSpectra(SpecSet &specSet, SpecSet &specSetSplit) {
	specSetSplit.resize(2*specSet.size());
	unsigned int fwdIdx = 0;
	for(unsigned int specIdx=0; specIdx<specSet.size(); specIdx++) {
		specSetSplit[fwdIdx] = specSet[specIdx];               // forward copy
		specSet[specIdx].reverse(0, &specSetSplit[fwdIdx+1]);  // reversed copy right after it
		fwdIdx += 2;
	}
}

/*
 * SplitAligns - converts alignments between whole spectra into alignments between the
 *   forward/reversed copies produced by SplitSpectra (spectrum s -> entries 2*s and 2*s+1).
 *
 * For each input alignment (A,B), A-vs-B and reversed(A)-vs-B are aligned with spec_align.
 * Each is scored by the summed matched intensities on both sides. If reversed(A)-vs-B
 * scores strictly higher, the pair is flipped and the two output alignments are
 * (A, rev B) and (rev A, B). Otherwise they are (A, B) and (rev A, rev B).
 *
 * alignsNew        - output, two entries per input alignment (2*i, 2*i+1); only spec1/spec2 are set.
 * matches          - output, same indexing as alignsNew; matched peak indices obtained with
 *                    massesToIndices ([0] = index in spec1, [1] = index in spec2).
 * pairFlipped      - output, one flag per input alignment.
 * dbg_matchScores  - optional, four values per alignment:
 *                    {A-side score in (A,B), A-side score in (revA,B),
 *                     B-side score in (A,B), B-side score in (revA,B)}.
 * debug            - not used.
 */
void SplitAligns(SpecSet &specSetSplit, vector<Results_ASP> &aligns,
				float peakTol, int maxAAjump, float penalty_sameVert, float penalty_ptm, bool forceSymmetry,
				vector<Results_ASP> &alignsNew, vector<vector<TwoValues<int> > > &matches, 
				vector<bool> &pairFlipped, vector<vector<float> > *dbg_matchScores, ofstream *debug) {
	if(dbg_matchScores) {
		dbg_matchScores->resize(aligns.size());
		for(unsigned int k=0; k<aligns.size(); k++) (*dbg_matchScores)[k].resize(4);
	}

	vector<Spectrum> results(4);   // [0],[1]: first pairing; [2],[3]: second pairing
	alignsNew.resize(2*aligns.size());
	matches.resize(2*aligns.size());
	pairFlipped.resize(aligns.size());

	for(unsigned int a=0; a<aligns.size(); a++) {
		const int fwdA = 2*aligns[a].spec1, revA = fwdA+1;
		const int fwdB = 2*aligns[a].spec2, revB = fwdB+1;
		const int outIdx = 2*a;

		// Score A-vs-B (o=0, into results[0..1]) and revA-vs-B (o=1, into results[2..3]).
		float both[2] = {0,0}, sideA[2] = {0,0}, sideB[2] = {0,0};
		const int firstSpec[2] = {fwdA, revA};
		for(unsigned int o=0; o<2; o++) {
			Spectrum &lhs = results[2*o];
			Spectrum &rhs = results[2*o+1];
			spec_align(&specSetSplit[firstSpec[o]],&specSetSplit[fwdB],peakTol,&lhs,&rhs,maxAAjump,penalty_sameVert,penalty_ptm,forceSymmetry);
			for(unsigned int p=0; p<lhs.size(); p++) {
				both[o]  += lhs[p][1]+rhs[p][1];
				sideA[o] += lhs[p][1];
				sideB[o] += rhs[p][1];
			}
		}

		if(dbg_matchScores) {
			vector<float> &dbg = (*dbg_matchScores)[a];
			dbg[0]=sideA[0];   dbg[1]=sideA[1];
			dbg[2]=sideB[0];   dbg[3]=sideB[1];
		}

		// Complete the missing pairing so results[0..1] / [2..3] match alignsNew[outIdx] / [outIdx+1].
		const bool flip = both[1] > both[0];
		pairFlipped[a] = flip;
		if(flip) {
			spec_align(&specSetSplit[fwdA],&specSetSplit[revB],peakTol,&results[0],&results[1],maxAAjump,penalty_sameVert,penalty_ptm,forceSymmetry);
			alignsNew[outIdx].spec1 = fwdA;     alignsNew[outIdx].spec2 = revB;
			alignsNew[outIdx+1].spec1 = revA;   alignsNew[outIdx+1].spec2 = fwdB;
		} else {
			spec_align(&specSetSplit[revA],&specSetSplit[revB],peakTol,&results[2],&results[3],maxAAjump,penalty_sameVert,penalty_ptm,forceSymmetry);
			alignsNew[outIdx].spec1 = fwdA;     alignsNew[outIdx].spec2 = fwdB;
			alignsNew[outIdx+1].spec1 = revA;   alignsNew[outIdx+1].spec2 = revB;
		}

		// Translate matched masses into peak indices of the respective split spectra.
		vector<int> indices;
		for(unsigned int half=0; half<2; half++) {
			vector<TwoValues<int> > &pairMatches = matches[outIdx+half];
			specSetSplit[alignsNew[outIdx+half].spec1].massesToIndices(results[2*half].peakList, indices, peakTol);
			pairMatches.resize(indices.size());
			for(unsigned int k=0; k<indices.size(); k++) pairMatches[k][0]=indices[k];
			specSetSplit[alignsNew[outIdx+half].spec2].massesToIndices(results[2*half+1].peakList, indices, peakTol);
			for(unsigned int k=0; k<indices.size(); k++) pairMatches[k][1]=indices[k];
		}
	}
}

/*
 * SplitLabels - builds two label entries per input entry, matching SplitSpectra's layout.
 *   newLabels[2*i] is a copy of labels[i]; newLabels[2*i+1] is a copy that is then reversed.
 */
void SplitLabels(vector<SpectrumPeakLabels> &labels, vector<SpectrumPeakLabels> &newLabels) {
	newLabels.resize(2*labels.size());
	for(unsigned int labelIdx=0; labelIdx<labels.size(); labelIdx++) {
		const unsigned int fwd = 2*labelIdx;
		const unsigned int rev = fwd+1;
		newLabels[fwd] = labels[labelIdx];
		newLabels[rev] = labels[labelIdx];
		newLabels[rev].reverse();
	}
}

/*
 * SplitPairs3 - chooses an orientation (original or reversed) for every spectrum taking part
 *   in an alignment, and derives the matched peak indices of each alignment under the chosen
 *   orientations.
 *
 * Spectra are fixed one at a time, greedily.
 *  - Seed: spec1 of the alignment (ASP or PA) with the highest
 *    min(score1 / total intensity of spec1, score2 / total intensity of spec2).
 *    The seed keeps its original orientation.
 *  - When a spectrum is fixed, each of its alignments whose partner is not fixed yet is
 *    evaluated against both orientations of the partner. The partner's flipScore accumulates
 *    (reversed-partner score - direct-partner score), normalized by total intensities.
 *  - Next fixed spectrum: the pending one with the largest |flipScore|. It is reversed iff its
 *    flipScore is > 0, and its alignments take the precomputed result for that orientation.
 *
 * specSet        - spectra; those chosen for reversal are overwritten with their reversed copy.
 * aligns         - ASP alignments (scored with spec_align); matches/modPos/alignRatios follow its indexing.
 * alignsPA       - PA alignments (scored at shift1 and shift2, keeping the better one);
 *                  matchesPA/alignRatiosPA follow its indexing.
 * peakTol, maxAAjump, penalty_sameVert, penalty_ptm, forceSymmetry
 *                - forwarded to spec_align; peakTol is also used for FindMatchPeaksAll,
 *                  ScoreOverlap6 and massesToIndices.
 * matches, matchesPA - output: per-alignment matched peak indices ([0] = spec1, [1] = spec2).
 * specFlipped    - output: set to true for reversed spectra (other entries are left untouched);
 *                  grown with false to specSet.size() if shorter.
 * modPos         - output: value returned by spec_align for the chosen orientation of each ASP alignment.
 * labelsP        - optional: labels of reversed spectra are reversed as well.
 * alignRatios, alignRatiosPA - optional outputs: min normalized matched intensity, [0] for the
 *                  chosen relative orientation and [1] for the alternative; grown if shorter.
 */
void SplitPairs3(SpecSet &specSet, vector<Results_ASP> &aligns, vector<Results_PA> &alignsPA,
				float peakTol, int maxAAjump, float penalty_sameVert, float penalty_ptm,
				vector<vector<TwoValues<int> > > &matches, vector<vector<TwoValues<int> > > &matchesPA, 
				vector<bool> &specFlipped, vector<float> &modPos,
				bool forceSymmetry, vector<SpectrumPeakLabels> *labelsP,
                vector<TwoValues<float> > *alignRatios, vector<TwoValues<float> > *alignRatiosPA) {
	vector<TwoValues<list<int> > > alignsEntries(specSet.size());   // per spectrum: [0] = indices into aligns, [1] = indices into alignsPA
	vector<float> flipScores(specSet.size());                        // >0 favours reversing the spectrum
	SpecSet specSetRev; specSetRev.resize(specSet.size());
	list<int> toProcess;                                             // spectra whose orientation is not fixed yet
	TwoValues<vector<vector<TwoValues<int> > > > tmpMatches;         // [0] = partner direct, [1] = partner reversed
	TwoValues<vector<vector<TwoValues<int> > > > tmpMatchesPA;
	vector<TwoValues<float> > tmpModPos(aligns.size());
	vector<bool> matchComputed(aligns.size());
	vector<bool> matchPAComputed(alignsPA.size());
	matches.resize(aligns.size());      modPos.resize(aligns.size());
	matchesPA.resize(alignsPA.size());
	vector<int> idx1, idx2; idx1.reserve(500); idx2.reserve(500);

	// Outputs indexed by spectrum / alignment below must be large enough: grow them, never shrink.
	if(specFlipped.size()<specSet.size()) specFlipped.resize(specSet.size(), false);
	if(alignRatios and alignRatios->size()<aligns.size()) alignRatios->resize(aligns.size());
	if(alignRatiosPA and alignRatiosPA->size()<alignsPA.size()) alignRatiosPA->resize(alignsPA.size());

	// Reverse every aligned spectrum once and queue it. NOTE(review): "already reversed" is detected
	// by equal sizes of the original and reversed copies, so an empty spectrum is never queued.
	tmpMatches[0].resize(aligns.size()); tmpMatches[1].resize(aligns.size());
	for(unsigned int i=0;i<aligns.size();i++) {
		if(specSet[aligns[i].spec1].size()!=specSetRev[aligns[i].spec1].size()) 
			{ specSet[aligns[i].spec1].reverse(0, &specSetRev[aligns[i].spec1]); toProcess.push_back(aligns[i].spec1); }
		if(specSet[aligns[i].spec2].size()!=specSetRev[aligns[i].spec2].size()) 
			{ specSet[aligns[i].spec2].reverse(0, &specSetRev[aligns[i].spec2]); toProcess.push_back(aligns[i].spec2); }
		matchComputed[i] = false;
		alignsEntries[aligns[i].spec1][0].push_back(i);   alignsEntries[aligns[i].spec2][0].push_back(i);
	}
	tmpMatchesPA[0].resize(alignsPA.size()); tmpMatchesPA[1].resize(alignsPA.size());  
	for(unsigned int i=0;i<alignsPA.size();i++) {
		if(specSet[alignsPA[i].spec1].size()!=specSetRev[alignsPA[i].spec1].size()) 
			{ specSet[alignsPA[i].spec1].reverse(0, &specSetRev[alignsPA[i].spec1]); toProcess.push_back(alignsPA[i].spec1); }
		if(specSet[alignsPA[i].spec2].size()!=specSetRev[alignsPA[i].spec2].size()) 
			{ specSet[alignsPA[i].spec2].reverse(0, &specSetRev[alignsPA[i].spec2]); toProcess.push_back(alignsPA[i].spec2); }
		matchPAComputed[i] = false;
		alignsEntries[alignsPA[i].spec1][1].push_back(i);   alignsEntries[alignsPA[i].spec2][1].push_back(i);
	}
	list<int>::iterator pIter;
	for(pIter=toProcess.begin(); pIter!=toProcess.end(); pIter++) flipScores[*pIter]=0;

	// Total intensity per spectrum, used to normalize every match score.
	// NOTE(review): a spectrum with zero total intensity yields inf/NaN ratios below.
	vector<float> specScores(specSet.size());
	float bestScore=0, curScore; int bestScoreIdx=-1, specToAdd=-1;
	for(unsigned int i=0;i<specSet.size(); i++) { specScores[i]=0;
		for(unsigned int j=0; j<specSet[i].size(); j++) specScores[i]+=specSet[i][j][1]; }
	for(unsigned int i=0;i<aligns.size(); i++) {
		curScore=min(aligns[i].score1/specScores[aligns[i].spec1],aligns[i].score2/specScores[aligns[i].spec2]);
		if(bestScoreIdx<0 or bestScore<curScore) { bestScoreIdx=i; bestScore=curScore; specToAdd=aligns[bestScoreIdx].spec1; }
	}
	for(unsigned int i=0;i<alignsPA.size(); i++) {
		curScore=min(alignsPA[i].score1/specScores[alignsPA[i].spec1],alignsPA[i].score2/specScores[alignsPA[i].spec2]);
		if(bestScoreIdx<0 or bestScore<curScore) { bestScoreIdx=i; bestScore=curScore; specToAdd=alignsPA[bestScoreIdx].spec1; }
	}

	// No alignments at all: nothing to orient (previously indexed flipScores[-1]).
	if(specToAdd<0) return;

	// The seed keeps its original orientation. Remove it from the queue if it is queued
	// (bounded search: the seed may be missing, e.g. an empty spectrum).
	flipScores[specToAdd]=-1;
	for(pIter=toProcess.begin(); pIter!=toProcess.end() and *pIter!=specToAdd; pIter++);
	if(pIter!=toProcess.end()) toProcess.erase(pIter);

	while(specToAdd>=0) {
		// Fix the orientation of specToAdd: reverse it iff the accumulated evidence favours it.
		int flipDir=0;
		if(flipScores[specToAdd]>0) { 
			specSet[specToAdd]=specSetRev[specToAdd]; flipDir=1; specFlipped[specToAdd]=true; 
			if(labelsP) (*labelsP)[specToAdd].reverse();
		}

		// ASP alignments involving specToAdd
		for(pIter = alignsEntries[specToAdd][0].begin(); pIter!=alignsEntries[specToAdd][0].end(); pIter++) {
			if(!matchComputed[*pIter]) {
				// Partner not fixed yet: align against both orientations of the partner.
				int otherSpec, otherSpecPos;
				if(aligns[*pIter].spec1==specToAdd) { otherSpec=aligns[*pIter].spec2; otherSpecPos=1; }
				else { otherSpec=aligns[*pIter].spec1; otherSpecPos=0; }

				TwoValues<float> matchScore1(0,0),matchScore2(0,0);
				vector<int> indices;
				vector<Spectrum> results(4);

				// [0]: partner as currently stored in specSet
				tmpModPos[*pIter][0] = spec_align(&specSet[specToAdd],&specSet[otherSpec],peakTol,&results[0],&results[1],maxAAjump,penalty_sameVert,penalty_ptm,forceSymmetry,true);
				for(unsigned int i=0; i<results[0].size(); i++) { matchScore1[0]+=results[0][i][1]; matchScore2[0]+=results[1][i][1]; }
				specSet[specToAdd].massesToIndices(results[0].peakList, indices, peakTol);
				tmpMatches[0][*pIter].resize(indices.size());
				for(unsigned int i=0; i<indices.size(); i++) tmpMatches[0][*pIter][i][1-otherSpecPos]=indices[i];
				specSet[otherSpec].massesToIndices(results[1].peakList, indices, peakTol);
				for(unsigned int i=0; i<indices.size(); i++) tmpMatches[0][*pIter][i][otherSpecPos]=indices[i];

				// [1]: partner reversed
				tmpModPos[*pIter][1] = spec_align(&specSet[specToAdd],&specSetRev[otherSpec],peakTol,&results[2],&results[3],maxAAjump,penalty_sameVert,penalty_ptm,forceSymmetry,true);
				for(unsigned int i=0; i<results[2].size(); i++) { matchScore1[1]+=results[2][i][1]; matchScore2[1]+=results[3][i][1]; }
				specSet[specToAdd].massesToIndices(results[2].peakList, indices, peakTol);
				tmpMatches[1][*pIter].resize(indices.size());
				for(unsigned int i=0; i<indices.size(); i++) tmpMatches[1][*pIter][i][1-otherSpecPos]=indices[i];
				specSetRev[otherSpec].massesToIndices(results[3].peakList, indices, peakTol);
				for(unsigned int i=0; i<indices.size(); i++) tmpMatches[1][*pIter][i][otherSpecPos]=indices[i];

				// Evidence for reversing the partner: normalized gain of the reversed alignment.
				flipScores[otherSpec] += (matchScore2[1]-matchScore2[0])/specScores[otherSpec] + (matchScore1[1]-matchScore1[0])/specScores[specToAdd];
				if(alignRatios) (*alignRatios)[*pIter][0]=min(matchScore1[0]/specScores[specToAdd], matchScore2[0]/specScores[otherSpec]);
				if(alignRatios) (*alignRatios)[*pIter][1]=min(matchScore1[1]/specScores[specToAdd], matchScore2[1]/specScores[otherSpec]);
				matchComputed[*pIter] = true;
			} else {
				// Partner fixed earlier: take the precomputed result for specToAdd's orientation.
				matches[*pIter].resize(tmpMatches[flipDir][*pIter].size());
				for(unsigned int i=0;i<matches[*pIter].size();i++) matches[*pIter][i]=tmpMatches[flipDir][*pIter][i];
				modPos[*pIter] = tmpModPos[*pIter][flipDir];
				if(alignRatios and flipDir==1) { float tmp=(*alignRatios)[*pIter][0]; (*alignRatios)[*pIter][0]=(*alignRatios)[*pIter][1]; (*alignRatios)[*pIter][1]=tmp; }
			}
		}

		// PA alignments involving specToAdd
		for(pIter = alignsEntries[specToAdd][1].begin(); pIter!=alignsEntries[specToAdd][1].end(); pIter++) {
			if(!matchPAComputed[*pIter]) {
				int otherSpec, otherSpecPos;
				if(alignsPA[*pIter].spec1==specToAdd) { otherSpec=alignsPA[*pIter].spec2; otherSpecPos=1; }
				else { otherSpec=alignsPA[*pIter].spec1; otherSpecPos=0; }

				TwoValues<float> score1, score2, matchScore1, matchScore2;
				vector<int> matchA1, matchA2, matchB1, matchB2;

				// [0]: partner as currently stored, keep the better of shift1 / shift2
				FindMatchPeaksAll(specSet[specToAdd], specSet[otherSpec], alignsPA[*pIter].shift1, peakTol, idx1, idx2);
				ScoreOverlap6(specSet[specToAdd], idx1, specSet[otherSpec], idx2, alignsPA[*pIter].shift1, peakTol, matchA1, matchA2);
				score1.set(0,0); for(unsigned int i=0; i<matchA1.size();i++) { score1[0]+=specSet[specToAdd][matchA1[i]][1]; score1[1]+=specSet[otherSpec][matchA2[i]][1]; }
				FindMatchPeaksAll(specSet[specToAdd], specSet[otherSpec], alignsPA[*pIter].shift2, peakTol, idx1, idx2);
				ScoreOverlap6(specSet[specToAdd], idx1, specSet[otherSpec], idx2, alignsPA[*pIter].shift2, peakTol, matchB1, matchB2);
				score2.set(0,0); for(unsigned int i=0; i<matchB1.size();i++) { score2[0]+=specSet[specToAdd][matchB1[i]][1]; score2[1]+=specSet[otherSpec][matchB2[i]][1]; }
				if( (score1[0]/specScores[specToAdd]+score1[1]/specScores[otherSpec]) > (score2[0]/specScores[specToAdd]+score2[1]/specScores[otherSpec]) ) 
					{ matchScore1[0]=score1[0]; matchScore2[0]=score1[1]; tmpMatchesPA[0][*pIter].resize(matchA1.size()); for(unsigned int i=0;i<matchA1.size(); i++) { tmpMatchesPA[0][*pIter][i][1-otherSpecPos]=matchA1[i]; tmpMatchesPA[0][*pIter][i][otherSpecPos]=matchA2[i];} }
				else { matchScore1[0]=score2[0]; matchScore2[0]=score2[1]; tmpMatchesPA[0][*pIter].resize(matchB1.size()); for(unsigned int i=0;i<matchB1.size(); i++) { tmpMatchesPA[0][*pIter][i][1-otherSpecPos]=matchB1[i]; tmpMatchesPA[0][*pIter][i][otherSpecPos]=matchB2[i];} }

				// [1]: partner reversed, keep the better of shift1 / shift2
				FindMatchPeaksAll(specSet[specToAdd], specSetRev[otherSpec], alignsPA[*pIter].shift1, peakTol, idx1, idx2);
				ScoreOverlap6(specSet[specToAdd], idx1, specSetRev[otherSpec], idx2, alignsPA[*pIter].shift1, peakTol, matchA1, matchA2);
				score1.set(0,0); for(unsigned int i=0; i<matchA1.size();i++) { score1[0]+=specSet[specToAdd][matchA1[i]][1]; score1[1]+=specSetRev[otherSpec][matchA2[i]][1]; }
				FindMatchPeaksAll(specSet[specToAdd], specSetRev[otherSpec], alignsPA[*pIter].shift2, peakTol, idx1, idx2);
				ScoreOverlap6(specSet[specToAdd], idx1, specSetRev[otherSpec], idx2, alignsPA[*pIter].shift2, peakTol, matchB1, matchB2);
				score2.set(0,0); for(unsigned int i=0; i<matchB1.size();i++) { score2[0]+=specSet[specToAdd][matchB1[i]][1]; score2[1]+=specSetRev[otherSpec][matchB2[i]][1]; }
				if( (score1[0]/specScores[specToAdd]+score1[1]/specScores[otherSpec]) > (score2[0]/specScores[specToAdd]+score2[1]/specScores[otherSpec]) ) 
					{ matchScore1[1]=score1[0]; matchScore2[1]=score1[1]; tmpMatchesPA[1][*pIter].resize(matchA1.size()); for(unsigned int i=0;i<matchA1.size(); i++) { tmpMatchesPA[1][*pIter][i][1-otherSpecPos]=matchA1[i]; tmpMatchesPA[1][*pIter][i][otherSpecPos]=matchA2[i];} }
				else { matchScore1[1]=score2[0]; matchScore2[1]=score2[1]; tmpMatchesPA[1][*pIter].resize(matchB1.size()); for(unsigned int i=0;i<matchB1.size(); i++) { tmpMatchesPA[1][*pIter][i][1-otherSpecPos]=matchB1[i]; tmpMatchesPA[1][*pIter][i][otherSpecPos]=matchB2[i];} }

				flipScores[otherSpec] += (matchScore2[1]-matchScore2[0])/specScores[otherSpec] + (matchScore1[1]-matchScore1[0])/specScores[specToAdd];
				if(alignRatiosPA) (*alignRatiosPA)[*pIter][0]=min(matchScore1[0]/specScores[specToAdd], matchScore2[0]/specScores[otherSpec]);
				if(alignRatiosPA) (*alignRatiosPA)[*pIter][1]=min(matchScore1[1]/specScores[specToAdd], matchScore2[1]/specScores[otherSpec]);
				matchPAComputed[*pIter] = true;
			} else {
				matchesPA[*pIter].resize(tmpMatchesPA[flipDir][*pIter].size());
				for(unsigned int i=0;i<matchesPA[*pIter].size();i++) matchesPA[*pIter][i]=tmpMatchesPA[flipDir][*pIter][i];
				if(alignRatiosPA and flipDir==1) { float tmp=(*alignRatiosPA)[*pIter][0]; (*alignRatiosPA)[*pIter][0]=(*alignRatiosPA)[*pIter][1]; (*alignRatiosPA)[*pIter][1]=tmp; }
			}
		}

		// Next: the pending spectrum with the strongest orientation evidence (first one wins ties).
		if(toProcess.empty()) specToAdd=-1; else {
			list<int>::iterator bestNewSpec=toProcess.begin();  pIter=bestNewSpec;  pIter++;
			while(pIter!=toProcess.end()) { if(fabs(flipScores[*pIter])>fabs(flipScores[*bestNewSpec])) bestNewSpec=pIter; pIter++; }
			specToAdd = *bestNewSpec;   toProcess.erase(bestNewSpec);
		}
	}
}
