#include "DeNovoDp.h"
#include "AnnotatedSpectrum.h"
#include "auxfun.h" 

// A pair of PRM-graph nodes (left_idx < right_idx) together with the mass
// difference between them. fill_dp_table sorts these pairs by distance so the
// DP table is filled from short spans to long spans.
struct dis_pair  {
	dis_pair() : dis(0), left_idx(-1), right_idx(-1) {}

	// Orders pairs by ascending mass distance. A single const member taking a
	// const reference accepts every const/non-const combination of operands,
	// so no further overloads are needed.
	bool operator< (const dis_pair& other) const
	{
		return dis < other.dis;
	}

	mass_t dis;      // nodes[right_idx].mass - nodes[left_idx].mass
	int left_idx;
	int right_idx;
};




// fills in all cells in the dp_table according to the PrmGraph
void DeNovoDp::fill_dp_table(const PrmGraph *_prm, score_t sym_penalty)
{
	prm = (PrmGraph *)_prm;
	config=prm->config;
	const vector<Node>& nodes = prm->nodes;
	const vector<MultiEdge>& edges = prm->multi_edges;
	const int num_nodes = nodes.size();

	// forbidden window of 1 Daltons
	const mass_t pm_with_19 = prm->pm_with_19;
	const mass_t min_forbidden_sum=pm_with_19 - 1.0078 - config->get_tolerance()*1.5;
	const mass_t max_forbidden_sum=pm_with_19 - 1.0078 + config->get_tolerance()*1.5;
	const mass_t sym_axis = (pm_with_19 - 1.0) *0.5;

	int i,j;
	cells.resize(num_nodes);
	for (i=0; i<num_nodes; i++)
		cells[i].resize(num_nodes);

	vector<bool> skip_ind;
	skip_ind.resize(num_nodes,false);
	for (i=0; i<num_nodes; i++)
		if (nodes[i].in_edge_idxs.size()  == 0 && 
			nodes[i].out_edge_idxs.size() == 0)
			skip_ind[i]=true;

	forbidden_idxs.resize(num_nodes,-1);
	vector<dis_pair> pairs;
	for (i=0; i<num_nodes-1; i++)
		for (j=i+1; j<num_nodes; j++)
		{
			dis_pair p;
			p.dis = nodes[j].mass - nodes[i].mass;
			p.left_idx=i;
			p.right_idx=j;

			if (p.dis< 56.0)
				continue;

			pairs.push_back(p);

			// mark forbidden pairs
			mass_t sum = nodes[j].mass + nodes[i].mass;
			if (sum>min_forbidden_sum && sum<max_forbidden_sum)
			{
				cells[i][j].is_forbidden=1;

//				cout << "F: " << i << " " << j << endl;

				if (forbidden_idxs[j]<0)
				{
					forbidden_idxs[j]=i;
					forbidden_idxs[i]=j;
				}
			}
		}

	sort(pairs.begin(),pairs.end());

	// init diagonal
	for (i=0; i<num_nodes; i++)
		cells[i][i].score = nodes[i].score;

	// fill other cells
	for (i=0; i<pairs.size(); i++)
	{
		const int left_idx = pairs[i].left_idx;
		const int right_idx = pairs[i].right_idx;

		if (sym_axis - nodes[left_idx].mass > nodes[right_idx].mass - sym_axis)
		{
			int j;
			for (j=0; j<nodes[left_idx].out_edge_idxs.size(); j++)
			{
				int e_idx = nodes[left_idx].out_edge_idxs[j];
				score_t s = nodes[left_idx].score + edges[e_idx].max_variant_score + 
							cells[edges[e_idx].c_idx][right_idx].score;
				if (s > cells[left_idx][right_idx].score)
				{
					cells[left_idx][right_idx].score = s;
					cells[left_idx][right_idx].prev_edge_idx = e_idx;
				}
			}
		}
		else  // fill the right side
		{
			int j;
			for (j=0; j<nodes[right_idx].in_edge_idxs.size(); j++)
			{
				int e_idx = nodes[right_idx].in_edge_idxs[j];
				score_t s = nodes[right_idx].score + edges[e_idx].max_variant_score+ 
							cells[left_idx][edges[e_idx].n_idx].score;
				if (s > cells[left_idx][right_idx].score)
				{
					cells[left_idx][right_idx].score = s;
					cells[left_idx][right_idx].prev_edge_idx = e_idx;
				}
			}
		}

		if (cells[left_idx][right_idx].is_forbidden)
			cells[left_idx][right_idx].score -= sym_penalty;
	}


	
}



// Returns the single highest-scoring path stored in the DP table (over cells
// with left < right), rebuilt by following the cells' prev_edge_idx pointers.
// Returns an empty MultiPath when the table has fewer than two nodes or the
// best score is negative. sym_penalty is not used here; fill_dp_table already
// applied it to the forbidden cells.
MultiPath DeNovoDp::get_top_scoring_antisymetric_path( score_t sym_penalty) const
{
	MultiPath ret_path;

	// signed count: cells.size()-1 on an empty table would wrap around
	const int num_cells = cells.size();
	score_t max_score = NEG_INF;
	int best_left=-1, best_right=-1;

	for (int i=0; i<num_cells-1; i++)
		for (int j=i+1; j<num_cells; j++)
			if (cells[i][j].score > max_score)
			{
				best_left=i;
				best_right=j;
				max_score = cells[i][j].score;
			}

	if (best_left < 0 || max_score<0)
		return ret_path;

	// collect edges: an edge ending at r was a right-side step (move r back to
	// its source), otherwise it was a left-side step (move l to its target)
	vector<int> path_edges;
	int l=best_left, r=best_right;
	while (cells[l][r].prev_edge_idx>=0)
	{
		const int e_idx = cells[l][r].prev_edge_idx;
		path_edges.push_back(e_idx);

		const MultiEdge& edge = prm->multi_edges[e_idx];
		if (edge.c_idx==r)
			r = edge.n_idx;
		else
			l = edge.c_idx;
	}
	
	prm->create_path_from_edges(path_edges,ret_path);
	ret_path.path_score = max_score;

	return ret_path;
}




// A (partial) path in the PRM graph used by the DFS path searches: the edge
// indices taken so far, the accumulated score and the number of amino acids.
//
// The implicitly generated copy/move operations copy every member (the old
// hand-written operator= dropped num_aa and used memcpy, which is undefined
// for self-assignment). edge_idxs is value-initialized so copies never read
// indeterminate values.
struct edge_idx_set {
	edge_idx_set() : score(NEG_INF), length(0), num_aa(0), edge_idxs() {}

	// "Less" means "higher score": with std::push_heap/pop_heap this keeps the
	// lowest-scoring entry at heap[0], and std::sort orders by descending score.
	bool operator< (const edge_idx_set& other) const
	{
		return score > other.score;
	}

	score_t score;
	int length;        // number of valid entries in edge_idxs
	int num_aa;        // number of aas used to fill these edges
	int edge_idxs[64]; // max peptide size...
};

/***********************************************************************************
// Returns the top scoring paths, using a DFS search with branch-and-bound pruning.
//
// Every node with a non-negative score is tried as a start position. The best
// num_paths paths seen so far are kept in `heap`, ordered so that heap[0] is the
// weakest kept path. A branch is abandoned once its score plus the largest gain
// still reachable from the current node (taken from the DP cells) falls below
// heap[0], or when the fixed-size edge buffer of edge_idx_set is full. Paths are
// stored when they reach a node without outgoing edges and, under extra
// conditions, at intermediate nodes. Reaching a node whose forbidden partner is
// already on the path costs sym_penalty.
//
// paths       - output, resized to the number of paths returned (best first);
//               entries scoring below -40 are dropped.
// num_paths   - maximal number of paths to return; nothing is returned if <= 0.
// sym_penalty - score deducted when both nodes of a forbidden pair are used.
************************************************************************************/
void DeNovoDp::get_top_scoring_antisymetric_paths(vector<MultiPath>& paths, int num_paths, 
												  score_t sym_penalty) const
{
	// heap[0] is read below, so an empty heap is not allowed
	if (num_paths <= 0)
	{
		paths.clear();
		return;
	}

	const vector<Node>& nodes = prm->nodes;
	const vector<MultiEdge>& multi_edges = prm->multi_edges;
	const int num_nodes = nodes.size();
	const int last_heap_pos = num_paths - 1;

	// capacity of the fixed-size edge buffer in edge_idx_set
	const int max_path_edges = sizeof(edge_idx_set().edge_idxs) / sizeof(int);

	// the maximal score gain attainable by continuing a path from each node
	vector<score_t> max_gains(num_nodes, NEG_INF);

	// for each depth in the tree, the score added by the edge taken at that depth
	vector<score_t> added_scores(num_nodes, NEG_INF);

	// for each node, which outgoing branch is explored next
	vector<int> out_idx_counters(num_nodes, 0);

	// should this node be used as a start position
	vector<bool> ok_start_pos(num_nodes, true);

	// indicators for each node if it is used in the current path
	vector<bool> used_nodes(num_nodes, false);

	for (int i=0; i<num_nodes; i++)
	{
		for (int j=i+1; j<num_nodes; j++)
			if (cells[i][j].score>max_gains[i])
				max_gains[i]=cells[i][j].score;

		max_gains[i] -= cells[i][i].score; // remove the node's own score

		if (nodes[i].score<0)
			ok_start_pos[i]=false;
	}

	vector<edge_idx_set> heap(num_paths);
	prm->sort_outgoing_edges();

	for (int start_idx=0; start_idx<num_nodes; start_idx++)
	{
		if (! ok_start_pos[start_idx])
			continue;

		edge_idx_set current_path;
		int current_node = start_idx;
		current_path.score = nodes[start_idx].score;
		used_nodes[start_idx] = true;

		while (true)
		{
			const int num_out_edges = nodes[current_node].out_edge_idxs.size();

			if (out_idx_counters[current_node] >= num_out_edges)
			{
				// Stop only once we are back at an exhausted start node, so the
				// subtree of its last edge is explored too and the
				// used_nodes/out_idx_counters of that path are unwound.
				if (current_node == start_idx)
					break;

				// store path if it ends at a node without outgoing edges
				if (num_out_edges == 0 && current_path.score>heap[0].score)
				{
					pop_heap(heap.begin(),heap.end());
					heap[last_heap_pos] = current_path;
					push_heap(heap.begin(),heap.end());
				}

				// backtrack
				out_idx_counters[current_node] =0;
				used_nodes[current_node]=false;
				current_path.length--;
				current_node = multi_edges[current_path.edge_idxs[current_path.length]].n_idx;
				current_path.score -= added_scores[current_path.length];
				continue;
			}

			// prune this branch if it cannot beat the weakest stored path, or if
			// there is no room left in the edge buffer
			if (current_path.length >= max_path_edges ||
				current_path.score + max_gains[current_node] < heap[0].score)
			{
				out_idx_counters[current_node]++;
				continue;
			}

			// advance on the edge
			const int edge_idx = nodes[current_node].out_edge_idxs[out_idx_counters[current_node]];
			const MultiEdge& e = multi_edges[edge_idx];

			out_idx_counters[current_node]++;
			current_node = e.c_idx;
			used_nodes[current_node]=true;

			added_scores[current_path.length] = e.max_variant_score + nodes[e.c_idx].score;

			// penalize using both nodes of a forbidden pair
			if (forbidden_idxs[current_node]>=0 && used_nodes[forbidden_idxs[current_node]])
				added_scores[current_path.length] -= sym_penalty; 

			current_path.edge_idxs[current_path.length] = edge_idx;
			current_path.score += added_scores[current_path.length];
			current_path.length++;

			// also store a partial path when its last step added a positive score
			if (added_scores[current_path.length-1]>0 && current_path.score>0 &&
				current_path.score> heap[0].score &&  nodes[current_node].out_edge_idxs.size()>0)
			{
				pop_heap(heap.begin(),heap.end());
				heap[last_heap_pos] = current_path;
				push_heap(heap.begin(),heap.end());
			}
		}

		out_idx_counters[start_idx] = 0;
		used_nodes[start_idx] = false;
	}

	// best first; drop unfilled (NEG_INF) and very poor entries
	sort(heap.begin(),heap.end());

	while (heap.size()>0 && heap[heap.size()-1].score<-40)
		heap.pop_back();

	if (static_cast<int>(heap.size())<num_paths)
		num_paths = heap.size();

	paths.resize(num_paths);

	for (int i=0; i<num_paths; i++)
	{
		const edge_idx_set& entry = heap[i];
		vector<int> edge_idxs(entry.edge_idxs, entry.edge_idxs + entry.length);

		prm->create_path_from_edges(edge_idxs, paths[i]);
		paths[i].path_score = entry.score;
		paths[i].edge_idxs.assign(entry.edge_idxs, entry.edge_idxs + entry.length);
	}
}




/****************************************************************************
   Computes for each node and number of amino acids, what is the maximal
   score attainable by making X steps forward from that node.
   Finds these values by performing a BFS search from all nodes.
*****************************************************************************/
void DeNovoDp::find_max_gains_per_length(int max_length, 
										 vector< vector< score_t > >& max_gains) const

{
	const int num_nodes = prm->get_num_nodes();
	const vector<Node>& nodes = prm->get_nodes();
	const vector<MultiEdge>& edges = prm->get_multi_edges();
	int n,i;

	max_gains.resize(max_length+1);
	for (i=0; i<max_gains.size(); i++)
		max_gains[i].resize(num_nodes,0);

	for (n=0; n<num_nodes; n++)
		max_gains[0][n]=nodes[n].score;

	for (n=num_nodes-1; n>=0; n--)
	{
		const Node& node = nodes[n];
		const vector<int>& out_idxs = node.out_edge_idxs;
		int j;

		for (j=0; j<out_idxs.size(); j++)
		{
			const MultiEdge& edge = edges[out_idxs[j]];

			// loop over different number of amino acids in edge

			const int& num_aa = edge.num_aa;
			const int& c_idx  = edge.c_idx;
			const int last_aa_len = max_length - num_aa + 1;
			const score_t add_score = node.score + edge.max_variant_score;
			int l;

			for (l=num_aa; l<=last_aa_len; l++)
			{
				const score_t new_score =  add_score + max_gains[l-num_aa][c_idx];
				if (new_score>max_gains[l][n])
					max_gains[l][n]=new_score;
			}
		}

		for (i=1; i<=max_length; i++)
			if (max_gains[i][n]<max_gains[i-1][n])
				max_gains[i][n]=max_gains[i-1][n];
	}

/*	for (n=0; n<num_nodes; n++)
	{
		cout << left << setw(4) << n << "  ";
		for (i=0; i<=max_length; i++)	
			cout << setw(5) << max_gains[i][n] << " ";
		cout << endl;
	} */
}





/*******************************************************************************
 Returns the top scoring paths, using a DFS search with branch-and-bound pruning,
 keeping only paths whose number of amino acids lies in [min_length, max_length].

 Every node is tried as a start position. The best required_num_paths paths are
 kept in `heap`, ordered so that heap[0] is the weakest kept path. A branch is
 abandoned when it already uses more than max_length amino acids, when the
 fixed-size edge buffer of edge_idx_set is full, or when its score plus the
 bound from find_max_gains_per_length for the remaining amino acids falls below
 heap[0]. Reaching a node whose forbidden partner is already on the path costs
 sym_penalty.

 multi_paths        - output, resized to the number of paths returned (best
                      first); entries scoring below -40 are dropped.
 required_num_paths - maximal number of paths to return; nothing if <= 0.
 min_length/max_length - allowed range for a stored path's amino-acid count.
************************************************************************************/
void DeNovoDp::get_top_scoring_antisymetric_paths_with_length_limits(
					vector<MultiPath>& multi_paths, 
					int required_num_paths, 
					int min_length,
					int max_length,
					score_t sym_penalty) const
{
	// heap[0] is read below, and a negative max_length admits no path
	if (required_num_paths <= 0 || max_length < 0)
	{
		multi_paths.clear();
		return;
	}

	const vector<Node>& nodes = prm->nodes;
	const vector<MultiEdge>& multi_edges = prm->multi_edges;
	const int num_nodes = nodes.size();
	const int last_heap_pos = required_num_paths - 1;

	// capacity of the fixed-size edge buffer in edge_idx_set
	const int max_path_edges = sizeof(edge_idx_set().edge_idxs) / sizeof(int);

	// for each depth in the tree, the score added by the edge taken at that depth
	vector<score_t> added_scores(num_nodes, NEG_INF);

	// for each node, which outgoing branch is explored next
	vector<int> out_idx_counters(num_nodes, 0);

	// indicators for each node if it is used in the current path
	vector<bool> used_nodes(num_nodes, false);

	vector<edge_idx_set> heap(required_num_paths);
	prm->sort_outgoing_edges();

	// max_gains_for_length[l][n]: bound on the score attainable from node n
	// using at most l amino acids
	vector< vector<score_t> > max_gains_for_length;
	find_max_gains_per_length(max_length,max_gains_for_length);

	for (int start_idx=0; start_idx<num_nodes; start_idx++)
	{
		edge_idx_set current_path;
		int current_node = start_idx;
		current_path.score = nodes[start_idx].score;
		used_nodes[start_idx] = true;

		while (true)
		{
			const int num_out_edges = nodes[current_node].out_edge_idxs.size();

			if (out_idx_counters[current_node] >= num_out_edges)
			{
				// all branches of the start node have been explored
				if (current_node == start_idx)
					break;

				// store path if it ends at a node without outgoing edges
				if (num_out_edges == 0 &&
					current_path.num_aa >= min_length &&
					current_path.num_aa <= max_length &&
					current_path.score > heap[0].score)
				{
					pop_heap(heap.begin(),heap.end());
					heap[last_heap_pos] = current_path;
					push_heap(heap.begin(),heap.end());
				}

				// backtrack
				out_idx_counters[current_node] =0;
				used_nodes[current_node]=false;
				current_path.length--;

				const int path_length      = current_path.length;
				const MultiEdge& back_edge = multi_edges[current_path.edge_idxs[path_length]];
				current_path.num_aa -= back_edge.num_aa;
				current_node         = back_edge.n_idx;
				current_path.score  -= added_scores[path_length];
				continue;
			}

			// discard this branch if it already uses too many amino acids, the
			// edge buffer is full, or the score cannot improve enough
			// (short-circuit keeps the table index non-negative)
			if (current_path.num_aa > max_length ||
				current_path.length >= max_path_edges ||
				current_path.score + max_gains_for_length[max_length - current_path.num_aa][current_node]
					< heap[0].score)
			{
				out_idx_counters[current_node]++;
				continue;
			}

			// advance on the edge
			const int edge_idx = nodes[current_node].out_edge_idxs[out_idx_counters[current_node]];
			const MultiEdge& e = multi_edges[edge_idx];

			out_idx_counters[current_node]++;
			current_node = e.c_idx;
			used_nodes[current_node]=true;

			added_scores[current_path.length] = e.max_variant_score + nodes[e.c_idx].score;

			// penalize using both nodes of a forbidden pair
			if (forbidden_idxs[current_node]>=0 && used_nodes[forbidden_idxs[current_node]])
				added_scores[current_path.length] -= sym_penalty; 

			current_path.edge_idxs[current_path.length] = edge_idx;
			current_path.num_aa += e.num_aa;
			current_path.score += added_scores[current_path.length];
			current_path.length++;

			// also store a partial path when its last step did not cost much
			// and its length is within the limits
			if (added_scores[current_path.length-1]>-5 && 
				nodes[current_node].out_edge_idxs.size()>0 &&
				current_path.num_aa <= max_length &&
				current_path.num_aa >= min_length &&
				current_path.score> heap[0].score )
			{
				pop_heap(heap.begin(),heap.end());
				heap[last_heap_pos] = current_path;
				push_heap(heap.begin(),heap.end());
			}
		}

		out_idx_counters[start_idx] = 0;
		used_nodes[start_idx] = false;
	}

	// best first; drop unfilled (NEG_INF) and very poor entries
	sort(heap.begin(),heap.end());

	while (heap.size()>0 && heap[heap.size()-1].score<-40)
		heap.pop_back();

	int actual_num_paths = required_num_paths;
	if (static_cast<int>(heap.size())<actual_num_paths)
		actual_num_paths = heap.size();

	multi_paths.resize(actual_num_paths);

	for (int i=0; i<actual_num_paths; i++)
	{
		const edge_idx_set& entry = heap[i];
		vector<int> edge_idxs(entry.edge_idxs, entry.edge_idxs + entry.length);

		prm->create_path_from_edges(edge_idxs, multi_paths[i]);
		multi_paths[i].path_score = entry.score;
		multi_paths[i].edge_idxs.assign(entry.edge_idxs, entry.edge_idxs + entry.length);
	}

	// NOTE: disabled filter on the number of modified amino acids; it refers to
	// a max_num_modified limit that is not a parameter of this function.
/*	for (i=0; i<multi_paths.size(); i++)
	{
		const vector<int>& amino_acids =  multi_paths[i].peptide.get_amino_acids();
		int k;
		int num_mod=0;

		for (k=0; k<amino_acids.size();k++)
			if (amino_acids[k]>Val)
				num_mod++;

		if (num_mod>max_num_modified)
		{
			multi_paths[i]=multi_paths[multi_paths.size()-1];
			multi_paths.pop_back();
		}
	} */
}













