#include "PrmGraph.h"

#include <algorithm>
#include <utility>


// This is not a real heap, the size will usually be small and I am too lazy to implement
// it as a real heap...
struct PathHeap {
public:
	void init(int size) 
	{ 
		heap_size = size; 
		paths.resize(size); 
		scores.clear(); 
		scores.resize(size,NEG_INF);

		min_entry_idx=0;
		min_entry_score = NEG_INF;
	}

	score_t get_min_score() const { return min_entry_score; }

	void find_min_entry() 
	{
		int i;
		min_entry_idx = 0;
		min_entry_score = scores[0];

		for (i=1; i<scores.size(); i++)
			if (scores[i]<min_entry_score)
			{
				min_entry_score = scores[i];
				min_entry_idx = i;
				if (min_entry_score == NEG_INF)
					break;
			}
	}

	int get_num_real_entries() const
	{
		int i,n=0;
		for (i=0; i<scores.size(); i++)
			if (scores[i]>NEG_INF)
				n++;
		return n;
	}

	void add_path(const SeqPath& path)
	{
		if (path.path_score<=min_entry_score)
			return;
		
		paths[min_entry_idx]=path;
		scores[min_entry_idx]=path.path_score;
		find_min_entry();
	}

	void sort_paths() { sort(paths.begin(),paths.end()); }

	vector<SeqPath> get_paths() { return paths; }

private:
	int heap_size;
	int min_entry_idx;
	score_t min_entry_score;

	vector<score_t> scores;
	vector<SeqPath> paths;
};






// add the variant ptr and scores for this combo
void PrmGraph::add_and_score_edge_variants(const AA_combo& aa_combo, MultiEdge& edge)
{
	const vector<int>& aa_positions = config->get_aa_positions();
	const bool reaches_n_term = (nodes[edge.n_idx].type == NODE_N_TERM);
	const bool reaches_c_term = (nodes[edge.c_idx].type == NODE_C_TERM);
	int v;
	
	int *variant_ptr = (int *)config->get_variant_ptr(aa_combo.variant_start_idx);

	for (v=0; v<aa_combo.num_variants; v++)
	{
		int num_aa = *variant_ptr;
		int *aas = variant_ptr+1;

		int i;
		for (i=0; i<num_aa; i++)
			if (aa_positions[*(aas+i)])
				break;

		// need to check that the variant is not violating any of the position restrictions
		// such as +1 -1 positions
		if (i<num_aa)
		{
			if (reaches_n_term)
			{
				int a;
				for (a=0; a<num_aa; a++)
				{
					int aa_idx = aas[a];
					if (aa_positions[aa_idx] != 0 && aa_positions[aa_idx] != a+1)
						break; 
				}
				if (a<num_aa)
				{
					variant_ptr+= num_aa +1;
					continue;
				}
			}
			else // check for +1 positions
			{
				int a;
				for (a=0; a<num_aa; a++)
				{
					int aa_idx = aas[a];
					if (aa_positions[aa_idx] == 1)
						break; 
				}
				if (a<num_aa)
				{
					variant_ptr+= num_aa +1;
					continue;
				}
			}

			if (reaches_c_term)
			{
				int a;
				for (a=0; a<num_aa; a++)
				{
					int aa_idx = aas[a];
					if (aa_positions[aa_idx] != 0 && aa_positions[aa_idx] != a-num_aa )
						break; 
				}
				if (a<num_aa) // found a problem with one of the aa positions
				{
					variant_ptr+= num_aa +1;
					continue;
				}
			}
			else // check for -1 positions
			{
				int a;
				for (a=0; a<num_aa; a++)
				{
					int aa_idx = aas[a];
					if (aa_positions[aa_idx] == -1)
						break; 
				}
				if (a<num_aa)
				{
					variant_ptr+= num_aa +1;
					continue;
				}
			}

			

		}
		
		score_t variant_score = calc_edge_variant_score(edge,num_aa,aas);

		if (variant_score>edge.max_variant_score)
			edge.max_variant_score = variant_score;

		edge.variant_ptrs.push_back(variant_ptr);
		edge.variant_scores.push_back(variant_score);

		variant_ptr+= num_aa +1;
	}
}


// adds the relevant PathPos to the path and adjusts the other non-terminal values 
void SeqPath::add_edge_variant(const MultiEdge& edge, int e_idx, int variant_idx)
{
	PathPos new_pos;

	
	int *variant_ptr = edge.variant_ptrs[variant_idx];
	int num_aa = *variant_ptr++;
	int *aas = variant_ptr;

	if (num_aa != edge.num_aa)
	{
		cout << "Error: edge and variant mixup!" << endl;
		exit(1);
	}
	
	num_aa += edge.num_aa;
	
	new_pos.breakage = edge.n_break;
	new_pos.edge_idx = e_idx;
	new_pos.edge_prob = edge.variant_probs[variant_idx];
	new_pos.mass = edge.n_break->mass;
	new_pos.edge_varaint_score = edge.variant_scores[variant_idx];
	new_pos.node_score = edge.n_break->score;
	new_pos.node_idx = edge.n_idx;
	new_pos.aa = aas[0];

	path_score += new_pos.edge_varaint_score + new_pos.node_score;

	positions.push_back(new_pos);

	if (edge.num_aa == 1)
		return;

	int i;
	for (i=1; i<edge.num_aa; i++)
	{
		PathPos new_pos;

		new_pos.breakage = NULL;
		new_pos.aa = aas[i]; // the rest of the fields are initialized to the default (NULL) values
		positions.push_back(new_pos);
	}
	
}


/************************************************************************

  Expands the single multi path into all possible sequence variants.
  This is done inrementaly by expanding each multi edges

*************************************************************************/
void PrmGraph::expand_multi_path(const MultiPath& multi_path, 
								 vector<SeqPath>& seq_paths,
								 int max_num_paths) const
{
	const vector<AA_combo>& aa_edge_comobos = config->get_aa_edge_combos();
	seq_paths.resize(1);

	int i;
//	for (i=0; i<multi_path.breakages.size(); i++)
//	{
//		cout << i << " " << multi_path.breakages[i] << " " ;
//		if (multi_path.breakages[i])
//			cout << multi_path.breakages[i]->fragments.size() << endl;
//	}

//	prm.print();
//	exit(0);

	seq_paths[0].n_term_aa = multi_path.n_term_aa;
	seq_paths[0].c_term_aa = multi_path.c_term_aa;
	seq_paths[0].n_term_mass = multi_path.n_term_mass;
	seq_paths[0].c_term_mass = multi_path.c_term_mass;
	seq_paths[0].multi_path_rank = multi_path.original_rank;
	seq_paths[0].path_score = 0;
	seq_paths[0].positions.clear();

	int e;
	for (e=0; e<multi_path.edge_idxs.size(); e++)
	{
		const int e_idx = multi_path.edge_idxs[e];

		const MultiEdge& edge = multi_edges[e_idx];

		if (edge.get_num_variants() == 1)
		{
			const int * variant_ptr = edge.variant_ptrs[0];
			const int num_aa = *variant_ptr++;
			const int *aas = variant_ptr;
		
			int i;
			for (i=0; i<seq_paths.size(); i++)
			{
				seq_paths[i].add_edge_variant(edge,e_idx,0);
				
			}
		}
		else
		{
			vector<SeqPath> old_paths = seq_paths;
			seq_paths.resize(seq_paths.size()*edge.get_num_variants());

			int v;
			int idx=0;
			for (v=0; v<edge.get_num_variants(); v++)
			{
				int i;
				for (i=0; i<old_paths.size(); i++)
					seq_paths[idx++]=old_paths[i];
			}

			
			idx=0;
			for (v=0; v<edge.get_num_variants(); v++)
			{
				const int * variant_ptr =  edge.variant_ptrs[v];
				const int num_aa = *variant_ptr++;
				const int *aas = variant_ptr;

				int i;
				for (i=0; i<old_paths.size(); i++)
					seq_paths[idx++].add_edge_variant(edge,e_idx,v);
			}	
		}

		if (max_num_paths>0 && seq_paths.size()>max_num_paths)
		{
			sort(seq_paths.begin(),seq_paths.end());
			while (seq_paths.size()>max_num_paths)
				seq_paths.pop_back();
		}
	}

	const MultiEdge& last_edge = multi_edges[multi_path.edge_idxs[multi_path.edge_idxs.size()-1]];
	PathPos last_pos;

	last_pos.breakage = last_edge.c_break;
	last_pos.edge_idx =-1;
	last_pos.mass = last_edge.c_break->mass;
	last_pos.node_idx = last_edge.c_idx;
	last_pos.node_score = last_edge.c_break->score;


	for (i=0; i<seq_paths.size(); i++)
	{
		seq_paths[i].path_score += last_pos.node_score;
		seq_paths[i].positions.push_back(last_pos);
		seq_paths[i].make_seq_str(config);
	}
}


void PrmGraph::expand_all_multi_paths(const vector<MultiPath>& multi_paths, 
							 vector<SeqPath>& paths, int max_num_paths) const
{
	// expand all variants
	int i;
	PathHeap path_heap;

	path_heap.init(max_num_paths);
	paths.clear();

	for (i=0; i<multi_paths.size(); i++ )
	{
		vector<SeqPath> variants;

		if (multi_paths[i].path_score<path_heap.get_min_score())
			continue;

		expand_multi_path(multi_paths[i],variants,max_num_paths);

		sort(variants.begin(),variants.end());
		int j;
		for (j=0; j<variants.size(); j++)
		{
			variants[j].multi_path_rank = i;

			// check that the peptide has correct mass if it spans the entire graph

			int n_idx = variants[j].positions[0].node_idx;
			int c_idx = variants[j].positions[ variants[j].positions.size()-1].node_idx;

			if (nodes[n_idx].type == NODE_N_TERM &&	nodes[c_idx].type == NODE_C_TERM) 
			{
				if (! config->get_need_to_estimate_pm() )
				{
					const vector<mass_t>& aa2mass = config->get_aa2mass();
					vector<int> aas;
					variants[j].get_amino_acids(aas);
					mass_t pep_mass = 0;
					int a;
					for (a=0; a<aas.size(); a++)
						pep_mass += aa2mass[aas[a]];

					pep_mass+=19.0183;

					if (fabs(pep_mass-source_spectrum->get_org_pm_with_19())>config->get_pm_tolerance())
					{
					//	cout << "Rejected: " << variants[j].seq_str << endl;
					//	cout << "aa_mass: " << fixed << setprecision(3) << pep_mass << " spec_mass: " << source_spectrum->get_org_pm_with_19();
					//	cout << " (" << pep_mass-source_spectrum->get_org_pm_with_19() << ")" << endl;

						continue;
					}

					// add bonus score since the amino acids match the mass
				//	variants[j].path_score += BONUS_FOR_COMPLETE_PEPTIDE;
				//	cout << "Added bonus!: " << fixed << setprecision(3) <<  variants[j].path_score << endl;
				}
			}

			// check that there are no PTMs that are specific to the +1, -1 positions

		
			path_heap.add_path(variants[j]);
		
		}
	//	cout << path_heap.get_num_real_entries() << " ";
	}
//	cout << endl;

	paths = path_heap.get_paths();
	sort(paths.begin(),paths.end());

	while (paths.size()>0)
	{
		int idx = paths.size()-1;
		if (paths[idx].path_score<= NEG_INF ||
			paths[idx].get_num_aa() < 1)
		{
			paths.pop_back();
		}
		else
			break;
	}
}



bool MultiPath::check_if_correct(const Peptide& p, Config *config) const
{
	const mass_t tolerance = config->get_tolerance() * 1.25;
	vector<mass_t> break_masses;
	int idx=0;
	int i;

	p.calc_expected_breakage_masses(config,break_masses);

	for (i=0; i<breakages.size(); i++)
	{
		const mass_t& mass = breakages[i]->mass;
		const mass_t max_mass = mass + tolerance;
		const mass_t min_mass = mass - tolerance;

		while (idx < break_masses.size() && break_masses[idx] < min_mass)
			idx++;

		if (break_masses[idx]>max_mass)
			return false;
	}

	return true;
}



int  MultiPath::get_num_correct_aas(const PrmGraph& prm, const Peptide& p, Config *config) const
{
	const mass_t tolerance = config->get_tolerance() * 1.25;
	vector<mass_t> break_masses;
	int idx=0;
	int num_correct=0;
	int i;

	p.calc_expected_breakage_masses(config,break_masses);

	for (i=0; i<breakages.size(); i++)
	{
		const mass_t& mass = breakages[i]->mass;
		const mass_t max_mass = mass + tolerance;
		const mass_t min_mass = mass - tolerance;

		while (idx < break_masses.size() && break_masses[idx] < min_mass)
			idx++;

		if (break_masses[idx]>max_mass)
			continue;
		
		if (idx<breakages.size()-1 && edge_idxs[idx]>=0)
			num_correct += prm.get_multi_edge(edge_idxs[idx]).num_aa;
	}
	return num_correct;
}


// Number of amino acids spanned by this multi path: one fewer than the number
// of breakages, or 0 for a path with no breakages (avoids size_t underflow).
int  MultiPath::get_num_aas() const
{
	return breakages.empty() ? 0 : static_cast<int>(breakages.size()) - 1;
}


// returns the number of b/y ions
int SeqPath::get_num_frags(const vector<int>& frag_idxs) const
{
	int num_frags=0;
	int i;
	for (i=0; i<positions.size(); i++)
	{
		Breakage *bb = positions[i].breakage;

		if (positions[i].breakage && positions[i].breakage->fragments.size()>0)
		{
			int j;
			for (j=0; j<frag_idxs.size(); j++)
			{
				int k;
				for (k=0; k<positions[i].breakage->fragments.size(); k++)
				{
					if (positions[i].breakage->fragments[k].frag_type_idx == frag_idxs[j])
						num_frags++;
				}
			}
		}
	}
	return num_frags;
}




// Counts residues of this path that coincide with residues of pep. A path
// residue is correct when pep has the same amino acid at a position whose
// prefix mass (sum of the preceding pep residues) is within 1 Da of the
// path's prefix mass (n_term_mass plus the preceding path residues).
int SeqPath::get_num_correct_aas(Peptide& pep, Config *config) const
{
	const vector<mass_t>& aa2mass = config->get_aa2mass();
	const vector<int>& pep_aas = pep.get_amino_acids();

	// prefix_masses[k] = mass of pep residues 0..k-1
	vector<mass_t> prefix_masses(pep_aas.size(), 0);
	for (size_t k=1; k<pep_aas.size(); k++)
		prefix_masses[k] = prefix_masses[k-1] + aa2mass[pep_aas[k-1]];

	vector<int> path_aas;
	get_amino_acids(path_aas);

	int num_correct=0;
	mass_t path_mass = n_term_mass;
	for (size_t a=0; a<path_aas.size(); a++)
	{
		const int aa = path_aas[a];
		for (size_t k=0; k<pep_aas.size(); k++)
		{
			if (pep_aas[k] == aa && fabs(prefix_masses[k]-path_mass)<1.0)
			{
				num_correct++;
				break;
			}
		}
		path_mass += aa2mass[aa];
	}

	return num_correct;
}





// Splits this path into every sub-path that starts at a position with an
// amino acid and an outgoing edge (aa>=0, edge_idx>=0) and spans between
// min_length and max_length amino acids.
// Each sub-path:
//  - ends on a node position; its c_term_mass is that position's mass;
//  - scores the sum of node and edge-variant scores of its positions, plus
//    the closing node's score;
//  - inherits n_term_aa if it starts at the first position, and c_term_aa if
//    it ends at the last position.
// The sub-paths are written to new_paths with their sequence strings built.
void SeqPath::parse_path_to_smaller_ones(Config *config,int min_length, int max_length, 
										 vector<SeqPath>& new_paths) const
{
	new_paths.clear();

	const int num_seq_positions = positions.size();
	if (num_seq_positions<2)
		return;

	// signed arithmetic: with size_t, min_length > size would wrap around
	for (int i=0; i<num_seq_positions-min_length; i++)
	{
		if (positions[i].aa<0 || positions[i].edge_idx<0)
			continue;

		SeqPath path;
		path.positions.clear();
		if (i==0)
			path.n_term_aa = n_term_aa;

		path.n_term_mass = positions[i].mass;
		path.path_score = 0;
		path.multi_path_rank = multi_path_rank;

		int pos=i;
		while (pos<num_seq_positions)
		{
			path.positions.push_back(positions[pos]);
			path.path_score += positions[pos].node_score;
			path.path_score += positions[pos].edge_varaint_score;
			pos++;

			// interior residues of a multi-residue edge (no node of their own)
			while (pos<num_seq_positions && positions[pos].node_idx<0)
				path.positions.push_back(positions[pos++]);

			const int length = pos-i;
			if (length>max_length)
				break; // sub-paths from this start only get longer

			if (pos<num_seq_positions && length>=min_length)
			{
				SeqPath add_path = path;
				add_path.path_score+=positions[pos].node_score;
				add_path.c_term_mass = positions[pos].mass;
				if (pos==num_seq_positions-1) // ends where this path ends
					add_path.c_term_aa= c_term_aa;

				add_path.positions.push_back(positions[pos]);
				new_paths.push_back(add_path);
			}
		}
	}

	for (size_t k=0; k<new_paths.size(); k++)
		new_paths[k].make_seq_str(config);
}



// Returns true if seq_str occurs in the correct sequence `str` (I and L are
// interchangeable) and the match agrees with n_term_mass. It agrees when
// either:
//  - n_term_mass is 0 and the match is at the very start of `str`, or
//  - some prefix of the parsed correct peptide weighs within 6 Da of
//    n_term_mass.
// The prefix-mass test does not depend on the offset at which the string
// matched. So the first matching offset decides the result, and the peptide
// is parsed at most once.
bool SeqPath::check_if_correct(const string& str, Config *config) const
{
	const char *path_str = seq_str.c_str();
	const char *corr_str = str.c_str();

	const int len_path_str = strlen(path_str);
	const int len_corr_str = strlen(corr_str);

	if (len_path_str>len_corr_str)
		return false;

	for (int i=0; i<=len_corr_str-len_path_str; i++)
	{
		bool correct_seq = true;
		for (int j=0; j<len_path_str && correct_seq; j++)
		{
			const char p = path_str[j];
			const char c = corr_str[i+j];
			correct_seq = (p == c || (p == 'I' && c == 'L') || (p == 'L' && c == 'I'));
		}
		if (! correct_seq)
			continue;

		if (n_term_mass == 0 && i==0)
			return true;

		// Prefix-mass test. Its outcome is the same for every later offset,
		// and the shortcut above only applies at i==0, so return it directly.
		const vector<mass_t>& aa2mass = config->get_aa2mass();
		Peptide pep;
		pep.parse_from_string(config,corr_str);
		const vector<int>& aas = pep.get_amino_acids();

		mass_t mass=0;
		for (size_t k=0; k<aas.size(); k++)
		{
			mass+=aa2mass[aas[k]];
			if (fabs(mass-n_term_mass)<6) // 6 Da tolerance
				return true;

			if (mass>n_term_mass)
				break;
		}
		return false;
	}

	return false;
}





// Rebuilds seq_str from this path's residues, in this order:
//  - the N-terminal flanking residue, if n_term_aa>N_TERM;
//  - the label of every position except the last (the closing node);
//  - the C-terminal flanking residue, if c_term_aa>C_TERM.
void SeqPath::make_seq_str(Config *config)
{
	const vector<string>& aa2label = config->get_aa2label();
	string built;

	if (n_term_aa>N_TERM)
		built += aa2label[n_term_aa];

	for (size_t p=0; p+1<positions.size(); p++)
		built += aa2label[positions[p].aa];

	if (c_term_aa>C_TERM)
		built += aa2label[c_term_aa];

	seq_str = built;
}

// Writes this multi path's terminal masses, score and node indices to os
// (config is currently unused). Previously the node list went to cout
// regardless of os.
void MultiPath::print(Config *config, ostream& os) const
{
	os << "MultiPath: " << n_term_mass << " - " << c_term_mass << 
		   " score: " << path_score << "  ";

	os << " Nodes:";
	for (size_t i=0; i<node_idxs.size(); i++)
		os << " " << node_idxs[i];
	os << endl;
}


// Writes "<n_term_mass> <seq_str> <c_term_mass> (s: <path_score>)" plus a
// newline to os. The stream is left with precision 5.
void SeqPath::print(ostream& os) const
{
	os << setprecision(5)
	   << n_term_mass << " " << seq_str << " " << c_term_mass
	   << " (s: " << path_score << ")" << endl;
}


// Writes "<n_term_mass> <seq_str> (s: <path_score>)" at precision 5, then
// every position's edge_prob at precision 2, then a newline. The stream is
// left with precision 2.
void SeqPath::print_with_probs(ostream& os) const
{
	os << setprecision(5) << n_term_mass << " " << seq_str
	   << " (s: " << path_score << ")" << setprecision(2);

	for (size_t p=0; p<positions.size(); p++)
		os << " " << positions[p].edge_prob;

	os << endl;
}

// Writes the path summary line, then one line for every position except the
// last: its index, amino-acid label, node index and mass.
// All output goes to os; previously the per-position lines went to cout, and
// an empty path underflowed positions.size()-1.
void SeqPath::print_full(Config *config, ostream &os) const
{
	const vector<string>& aa2label = config->get_aa2label();
	os << n_term_mass << " " << seq_str << " " << c_term_mass << " (s: " << 
		this->path_score << ")" << endl;

	for (size_t i=0; i+1<positions.size(); i++)
	{
		os << left << setw(3) << i;
		os << setw(5) << left << aa2label[positions[i].aa] << " " <<
			positions[i].node_idx << " " << positions[i].mass << endl;
	}
}
