#include "ScoreDoubleAA.h"
#include "RegularRankModel.h"
#include "auxfun.h"

// Builds the ordered list of (n-side aa, c-side aa) combos that get their own
// double-aa model. Gap acts as a wildcard. get_daa_model_idx returns the first
// matching entry, so order defines precedence: specific combos first, and the
// all-Gap default last (anything placed after it could never be selected).
void ScoreDoubleAA::initialize_model_aa_combos()
{
	static const int combo_n_side[] = { Pro, Gap, Gly, Gap, Ser, Gap, His, Gap, Gap };
	static const int combo_c_side[] = { Gap, Pro, Gap, Gly, Gap, Ser, Gap, His, Gap };
	const int num_combos = sizeof(combo_n_side) / sizeof(combo_n_side[0]);

	model_n_aas.assign(combo_n_side, combo_n_side + num_combos);
	model_c_aas.assign(combo_c_side, combo_c_side + num_combos);
}





float ScoreDoubleAA::calc_variant_prob(const PrmGraph& prm, int me_idx, int* variant_ptr) const
{
	const MultiEdge& edge = prm.get_multi_edge(me_idx);

	if (edge.num_aa != 2)
	{
		cout << "Error: using double aa model for different num of aa!" << endl;
		exit(1);
	}

	ME_Regression_Sample sam;
	float prob = -1;

	if (is_inner_aa_edge(prm,me_idx))
	{
		fill_fval_vector_for_inner_edge(prm,me_idx,variant_ptr,sam);
		int model_idx = get_daa_model_idx(*(variant_ptr+1),*(variant_ptr+2));
		prob=(float)daa_me_models[model_idx]->p_y_given_x(0,sam);
	}
	else
	{
		fill_fval_vector_for_ncd_edge(prm,me_idx,variant_ptr,sam);
		prob=(float)daancd_me_model->p_y_given_x(0,sam);
	}
	
	if (prob<0.01)
		prob=0.01;

	if (prob>0.99)
		prob=0.99;

	return prob;
}



// Returns the index of the model to use for the amino-acid pair (n_aa, c_aa).
// model_n_aas and model_c_aas are scanned in parallel. The first entry whose
// n-side and c-side each either equal the given amino acid or are Gap
// (wildcard) is returned, so earlier entries take precedence.
// Exits the process if no entry matches.
int ScoreDoubleAA::get_daa_model_idx(int n_aa, int c_aa) const
{
	const size_t num_combos = model_n_aas.size();

	for (size_t k = 0; k < num_combos; ++k)
	{
		const bool n_side_ok = (model_n_aas[k] == Gap) || (model_n_aas[k] == n_aa);
		if (! n_side_ok)
			continue;

		const bool c_side_ok = (model_c_aas[k] == Gap) || (model_c_aas[k] == c_aa);
		if (c_side_ok)
			return static_cast<int>(k);
	}

	cout << "Error: bad model aas, couldn't find match for " << n_aa << " " << c_aa << endl;
	exit(1);
}



bool ScoreDoubleAA::is_inner_aa_edge(const PrmGraph& prm, int me_idx) const
{
	const MultiEdge& edge = prm.get_multi_edge(me_idx);
	const Node& n_node = prm.get_node(edge.n_idx);
	const Node& c_node = prm.get_node(edge.c_idx);

	if (n_node.type == NODE_N_TERM || c_node.type == NODE_C_TERM)
		return false;

	if (n_node.breakage.fragments.size()>0 && c_node.breakage.fragments.size()>0)
		return true;

	return false;
}




// Fills sam.f_vals with the feature vector used by the inner-edge DAA models
// for one double-amino-acid variant of multi-edge me_idx.
//   variant_ptr - points to {num_aas, aa1, aa2}; exits if num_aas != 2.
//   sam         - its f_vals are cleared, filled, and sorted before returning.
// Feature groups:
//   - n/c node scores against 0, 5, and 1/3 and 2/3 of the graph's max node score
//   - node log-rank features and fragment-count features
//   - whether the two nodes are each other's max-score in/out neighbours
//   - node and peak mass offsets vs. the expected mass of the two amino acids
//   - the best alternate route through two single-aa edges
//   - a flag for a fixed list of problematic amino-acid pairs
// Exits if the node mass offset exceeds 3.0, if the nodes' max in/out indices
// are unset, or if an alternate-route offset exceeds 5.
void ScoreDoubleAA::fill_fval_vector_for_inner_edge(const PrmGraph& prm, int me_idx, 
									 int* variant_ptr, ME_Regression_Sample& sam ) const
{
	const Config *config = prm.get_config();
	const vector<mass_t>& aa2mass =config->get_aa2mass();

	vector<fval>& fvals = sam.f_vals;

	if (*variant_ptr != 2)
	{
		cout << "Error: using a double aa score for a variant with " << *variant_ptr << " aas." << endl;
		exit(1);
	}

	const int var_aa1 = *(variant_ptr+1);
	const int var_aa2 = *(variant_ptr+2);

	// expected mass gap between the two nodes for this variant
	const mass_t exp_mass = aa2mass[var_aa1] + aa2mass[var_aa2];
	// score thresholds relative to the graph's max node score (0 if that max is negative)
	const score_t third_score = (prm.get_max_node_score()<0) ? 0 : prm.get_max_node_score() * 0.3333;
	const score_t two_thirds_score = (prm.get_max_node_score()<0) ? 0 : prm.get_max_node_score() * 0.66666;

	fvals.clear();
	const MultiEdge& me = prm.get_multi_edge(me_idx);
	const Node& n_node  = prm.get_node(me.n_idx);
	const Node& c_node  = prm.get_node(me.c_idx);

	fvals.push_back(fval(DAA_CONST,1.0));


	bool n_above_two_thirds = (n_node.score >= two_thirds_score);
	bool c_above_two_thirds = (c_node.score >= two_thirds_score);
	bool n_above_third = (n_node.score >= third_score);
	bool c_above_third = (c_node.score >= third_score);
	bool n_above_zero = (n_node.score >= 0);
	bool c_above_zero = (c_node.score >= 0);
	bool n_above_five = (n_node.score >= 5.0);
	bool c_above_five = (c_node.score >= 5.0);


	// Score-band indicators: each node lands in at most one band
	// (>=2/3, [1/3,2/3), [0,1/3)), plus "both" indicators per band.
	if (n_above_two_thirds)
		fvals.push_back(fval(DAA_N_SCORE_ABOVE_TWO_THIRDS,1));

	if (c_above_two_thirds)
		fvals.push_back(fval(DAA_C_SCORE_ABOVE_TWO_THIRDS,1));

	if (n_above_two_thirds && c_above_two_thirds)
		fvals.push_back(fval(DAA_IND_TWO_THIRDS_BOTH_ABOVE,1));

	if (! n_above_two_thirds && n_above_third)
		fvals.push_back(fval(DAA_N_SCORE_ABOVE_THIRD, 1 ));

	if (! c_above_two_thirds && c_above_third)
		fvals.push_back(fval(DAA_C_SCORE_ABOVE_THIRD, 1 ));

	if (! (n_above_two_thirds && c_above_two_thirds) && n_above_third && c_above_third)
		fvals.push_back(fval(DAA_IND_THIRD_BOTH_ABOVE, 1 ));

	if ( ! n_above_third && n_above_zero)
		fvals.push_back(fval(DAA_N_SCORE_ABOVE_ZERO, 1));

	if ( ! c_above_third && c_above_zero)
		fvals.push_back(fval(DAA_C_SCORE_ABOVE_ZERO, 1));

	if (! (n_above_third && c_above_third) && n_above_zero && c_above_zero)
		fvals.push_back(fval(DAA_IND_ZERO_BOTH_ABOVE, 1 ));

	// absolute threshold of 5: exactly one of these indicators, or none
	if (n_above_five && c_above_five)
	{
		fvals.push_back(fval(DAA_IND_BOTH_ABOVE_FIVE, 1 ));
	}
	else if (n_above_five)
	{
		fvals.push_back(fval(DAA_IND_N_ABOVE_FIVE, 1 ));
	}
	else if (c_above_five)
	{
		fvals.push_back(fval(DAA_IND_C_ABOVE_FIVE, 1 ));
	}

	// log ranks of the higher- and lower-scoring node (ties go to c_node as "max")
	float max_score_rank=-1,min_score_rank=-1;

	if (n_node.score > c_node.score)
	{
		max_score_rank = n_node.log_rank;
		min_score_rank = c_node.log_rank;
	}
	else
	{
		max_score_rank = c_node.log_rank;
		min_score_rank = n_node.log_rank;
	}

	fvals.push_back(fval(DAA_MAX_SCORE_RANK,max_score_rank));
	fvals.push_back(fval(DAA_MIN_SCORE_RANK,min_score_rank));


	fvals.push_back(fval(DAA_N_SCORE_RANK, n_node.log_rank)); 
	fvals.push_back(fval(DAA_C_SCORE_RANK, c_node.log_rank));
	fvals.push_back(fval(DAA_SCORE_RANK_SUM, n_node.log_rank+c_node.log_rank));
	fvals.push_back(fval(DAA_SCORE_RANK_DIFF, n_node.log_rank-c_node.log_rank));
	fvals.push_back(fval(DAA_SCORE_RANK_ABS_DIFF, fabs(n_node.log_rank-c_node.log_rank) ));

	const int num_n_frags = n_node.breakage.fragments.size();
	const int num_c_frags = c_node.breakage.fragments.size();

	fvals.push_back(fval(DAA_N_NUM_FRAGS,num_n_frags ));
	fvals.push_back(fval(DAA_C_NUM_FRAGS,num_c_frags ));
	fvals.push_back(fval(DAA_NUM_FRAG_DIFF, num_n_frags - num_c_frags ));
	fvals.push_back(fval(DAA_ABS_NUM_FRAG_DIFF, abs(num_n_frags - num_c_frags) ));



	// connection features: is n_node the best-scoring predecessor of c_node,
	// and is c_node the best-scoring successor of n_node?
	int max_in_c_idx  = c_node.idx_max_in_score_node;
	int max_out_n_idx = n_node.idx_max_out_score_node;

	if (max_in_c_idx<0 || max_out_n_idx<0)
	{
		cout << "Error: max_in and max_out idxs not filled correctly!" << endl;
		exit(1);
	}

	if (max_in_c_idx == me.n_idx)
	{
		fvals.push_back(fval(DAA_IND_N_IS_MAX_IDX_TO_C,1));
	}
	else
	{
		fvals.push_back(fval(DAA_IND_N_NOT_MAX_IDX_TO_C,1));
		fvals.push_back(fval(DAA_DIFF_N_MAX_IN_C_SCORE_RANKS,
			n_node.log_rank - prm.get_node(max_in_c_idx).log_rank));
	}

	if (max_out_n_idx == me.c_idx)
	{
		fvals.push_back(fval(DAA_IND_C_IS_MAX_IDX_FROM_N,1));
	}
	else
	{
		fvals.push_back(fval(DAA_IND_C_NOT_MAX_IDX_FROM_N,1));
		fvals.push_back(fval(DAA_DIFF_C_MAX_OUT_N_SCORE_RANKS,
			c_node.log_rank - prm.get_node(max_out_n_idx).log_rank));
	}

	bool both_connect_to_max =  (max_in_c_idx == me.n_idx && max_out_n_idx == me.c_idx);
	if (both_connect_to_max)
		fvals.push_back(fval(DAA_IND_BOTH_CONNECT_TO_MAX,1));

	
	// node mass diff: |(c mass - n mass) - expected variant mass|

	mass_t node_mass_diff = fabs(c_node.mass - n_node.mass - exp_mass);

	if (node_mass_diff>3.0)
	{
		cout << "Error: large node mass diff: " << node_mass_diff << endl;
		exit(1);
	}

	fvals.push_back(fval(DAA_NODE_MASS_DIFF,node_mass_diff));
	fvals.push_back(fval(DAA_NODE_SQR_MASS_DIFF,node_mass_diff*node_mass_diff));

	// peak mass diff: for every fragment type present on both nodes, compare the
	// peak-to-peak difference with the expected variant mass

	mass_t best_mass_diff = 1000.0;	// sentinel: stays 1000 when no fragment type is shared
	int i;
	int num_pairs=0;
	mass_t avg_diff=0;
	const vector<BreakageFragment> & n_fragments = n_node.breakage.fragments;
	const vector<BreakageFragment> & c_fragments = c_node.breakage.fragments;
	for (i=0; i<n_fragments.size(); i++)
	{
		const int& frag_type_idx = n_fragments[i].frag_type_idx;
		const int pos = c_node.breakage.get_position_of_frag_idx(frag_type_idx);

		if (pos<0)
			continue;

		num_pairs++;

		const int charge=config->get_fragment(frag_type_idx).charge;

		// presumably fragment masses are m/z values; scaling by the fragment
		// charge turns the peak difference into a mass difference - TODO confirm
		mass_t mass_diff = fabs(n_fragments[i].mass - c_fragments[pos].mass);
		mass_diff *= charge;
		mass_diff -= exp_mass;
		mass_diff = fabs(mass_diff);

		avg_diff+=mass_diff;

		if (mass_diff<best_mass_diff)
			best_mass_diff = mass_diff;
	}


	// true only if at least one shared fragment type was found (num_pairs > 0)
	if (best_mass_diff<100.0)
	{
		avg_diff /= num_pairs;
		fvals.push_back(fval(DAA_NUM_FRAG_PAIRS,num_pairs));
		fvals.push_back(fval(DAA_AVG_PEAK_DIFF,avg_diff));
		fvals.push_back(fval(DAA_AVG_PEAK_SQR_DIFF,avg_diff*avg_diff));
		fvals.push_back(fval(DAA_BEST_PEAK_DIFF,best_mass_diff));
		fvals.push_back(fval(DAA_BEST_PEAK_SQR_DIFF,best_mass_diff*best_mass_diff));

		fvals.push_back(fval(DAA_AVG_PEAK_DIFF_TIMES_SCORE_RANK_SUM,
			avg_diff * (n_node.log_rank+c_node.log_rank) ));
		fvals.push_back(fval(DAA_AVG_PEAK_DIFF_TIMES_SCORE_ABS_DIFF,
			avg_diff * fabs(n_node.log_rank-c_node.log_rank) ));

		// NOTE(review): avg_diff was already divided by num_pairs above, so this
		// value is sum/num_pairs^2. Left unchanged because trained models use it.
		fvals.push_back(fval(DAA_AVG_PEAK_DIFF_DIV_NUM_FRAG_PAIRS,
			avg_diff / num_pairs));

		if (both_connect_to_max)
			fvals.push_back(fval(DAA_IND_BOTH_CONNECT_TO_MAX_TIMES_AVG_DIFF,avg_diff));

	}
	else
		fvals.push_back(fval(DAA_IND_NO_PEAK_DIFF,1.0));


	// features for the alternate route that consists of two amino acid edges (instead
	// of this double edge). The "best" route is the one whose middle node has the
	// highest score.

	int      num_alternate_routes=0;
	score_t best_alternate_route_score=NEG_INF;
	int      best_n_edge_idx=-1, best_c_edge_idx=-1, best_node_idx=-1;

	for (i=0; i<n_node.out_edge_idxs.size(); i++)
	{
		const MultiEdge& n_edge = prm.get_multi_edge(n_node.out_edge_idxs[i]);
		if (n_edge.num_aa>1)
			continue;

		int j;
		for (j=0; j<c_node.in_edge_idxs.size(); j++)
		{
			const MultiEdge& c_edge = prm.get_multi_edge(c_node.in_edge_idxs[j]);
			if (c_edge.num_aa>1)
				continue;

			if (n_edge.c_idx == c_edge.n_idx)
			{
				num_alternate_routes++;
				score_t node_score = prm.get_node(n_edge.c_idx).score;
				if (node_score>best_alternate_route_score)
				{
					best_alternate_route_score = node_score;
					best_n_edge_idx = n_node.out_edge_idxs[i];
					best_c_edge_idx = c_node.in_edge_idxs[j];
					best_node_idx = c_edge.n_idx;
				}
			}
		}
	}

	fvals.push_back(fval(DAA_NUM_DOUBLE_EDGE_ROUTES,(float)num_alternate_routes));

	if (best_alternate_route_score>NEG_INF)
	{
		const Node& node = prm.get_node(best_node_idx);
		
		fvals.push_back(fval(DAA_SCORE_RANK_DOUBLE_EDGE_ROUTES,node.log_rank));
		if (node.score>0)
			fvals.push_back(fval(DAA_IND_IND_MAX_SCORE_ALTERNATE_MORE_THAN_ZERO,1.0));

		if (best_node_idx == n_node.idx_max_out_score_node)
			fvals.push_back(fval(DAA_IND_MAX_ALTERNATE_IS_MAX_OUT_N,1.0));

		if (best_node_idx == c_node.idx_max_in_score_node)
			fvals.push_back(fval(DAA_IND_MAX_ALTERNATE_IS_MAX_IN_C,1.0));

		// masses of the first variant's amino acid on each single-aa edge
		mass_t e1_mass = aa2mass[*(prm.get_multi_edge(best_n_edge_idx).variant_ptrs[0]+1)];
		mass_t e2_mass = aa2mass[*(prm.get_multi_edge(best_c_edge_idx).variant_ptrs[0]+1)];

		mass_t off1 = fabs(e1_mass - node.mass + n_node.mass);
		mass_t off2 = fabs(e2_mass - c_node.mass + node.mass);

		if (off1>5 || off2>5)
		{
			cout << "Error: mismatches in the offset feature of double edges!" << endl;
			exit(1);
		}

		fvals.push_back(fval(DAA_NODE_OFFSETS_ALTERNTE, off1+off2));
		fvals.push_back(fval(DAA_SQR_NODE_ODFFSETS_ALTERNATE, off1*off1 + off2*off2));

	}
	else
		fvals.push_back(fval(DAA_IND_NO_DOUBLE_EDGE_ROUTES,1.0));

	// ordered amino-acid pairs (aa1 then aa2) that get a dedicated indicator;
	// the two arrays are read in parallel
	const int problem_aa1[]={Gly,Gly,Glu,Val,Ser,Ala,Asp,Ala,Gly};
	const int problem_aa2[]={Gly,Glu,Gly,Ser,Val,Asp,Ala,Gly,Ala};
	const int num_problem_aas = sizeof(problem_aa1)/sizeof(int);

	for (i=0; i<num_problem_aas; i++)
		if (var_aa1 == problem_aa1[i] && var_aa2 == problem_aa2[i])
		{
			fvals.push_back(fval(DAA_IND_PROBLEMATIC_PAIR_OF_AAS,1.0));
			break;
		}
	
	
	// NOTE(review): charge is unused; its only consumers are commented out below.
	const int charge = prm.get_source_spectrum()->get_charge();

/*	if (charge == 1)
	{
		fvals.push_back(fval(DAA_IND_CHARGE1,1.0));
		fvals.push_back(fval(DAA_CHARGE1_MAX_SCORE_RANK,max_score_rank));
		fvals.push_back(fval(DAA_CHARGE1_MIN_SCORE_RANK,min_score_rank));
	}
	else if (charge == 2)
	{
		fvals.push_back(fval(DAA_IND_CHARGE2,1.0));
		fvals.push_back(fval(DAA_CHARGE2_MAX_SCORE_RANK,max_score_rank));
		fvals.push_back(fval(DAA_CHARGE2_MIN_SCORE_RANK,min_score_rank));
	}
	else if (charge >= 3)
	{
		fvals.push_back(fval(DAA_IND_CHARGE3,1.0));
		fvals.push_back(fval(DAA_CHARGE3_MAX_SCORE_RANK,max_score_rank));
		fvals.push_back(fval(DAA_CHARGE3_MIN_SCORE_RANK,min_score_rank));
	} */
	
	sort(fvals.begin(),fvals.end());
}



// Fills sam.f_vals with the feature vector for the single DAANCD model.
// calc_variant_prob uses this model for double edges that is_inner_aa_edge
// rejects: edges touching a terminal node, or edges with a node that has no
// fragments.
//   variant_ptr - points to {num_aas, aa1, aa2}; exits if num_aas != 2.
//   sam         - its f_vals are cleared, filled, and sorted before returning.
// Features: node mass offset vs. the expected variant mass, split by which
// kind of special node (N-term, C-term, digest) the edge touches, plus the
// same two-single-aa-edge alternate-route features as the inner model.
// Exits if an alternate-route offset exceeds 5.
void ScoreDoubleAA::fill_fval_vector_for_ncd_edge(const PrmGraph& prm, int me_idx, 
						int* variant_ptr, ME_Regression_Sample& sam) const
{
	const Config *config = prm.get_config();
	const vector<mass_t>& aa2mass =config->get_aa2mass();

	vector<fval>& fvals = sam.f_vals;

	if (*variant_ptr != 2)
	{
		cout << "Error: using a double aa score for a variant with " << *variant_ptr << " aas." << endl;
		exit(1);
	}

	const int var_aa1 = *(variant_ptr+1);
	const int var_aa2 = *(variant_ptr+2);
	const mass_t exp_mass = aa2mass[var_aa1] + aa2mass[var_aa2];
	const MultiEdge&  edge = prm.get_multi_edge(me_idx);
	const Node& n_node = prm.get_node(edge.n_idx);
	const Node& c_node = prm.get_node(edge.c_idx);

	// |expected variant mass - (c mass - n mass)|
	mass_t mass_diff = fabs(exp_mass - c_node.mass + n_node.mass);
	mass_t sqr_diff  = mass_diff * mass_diff;

	fvals.clear();
	fvals.push_back(fval(DAANCD_CONST,1.0));

	if (n_node.breakage.fragments.size() == 0 &&
		c_node.breakage.fragments.size() == 0)
		fvals.push_back(fval(DAANCD_IND_CONNECTS_TO_NODE_WITH_NO_FRAGS,1.0));

	
	// edge leaves the N-terminal node: describe it by the c-side node
	if (n_node.type == NODE_N_TERM)
	{
		fvals.push_back(fval(DAANCD_IND_CONNECTS_TO_N_TERMINAL,1.0));
		if (n_node.idx_max_out_score_node == edge.c_idx)
			fvals.push_back(fval(DAANCD_IND_HAS_MAX_SCORE_FROM_N,1.0));

		fvals.push_back(fval(DAANCD_SCORE_FROM_N,c_node.score));
		fvals.push_back(fval(DAANCD_NODE_MASS_DIFF_FROM_N,mass_diff));
		fvals.push_back(fval(DAANCD_NODE_SQR_MASS_DIFF_FROM_N,sqr_diff));
	}

	// edge enters the C-terminal node: describe it by the n-side node
	if (c_node.type == NODE_C_TERM)
	{
		fvals.push_back(fval(DAANCD_IND_CONNECTS_TO_C_TERMINAL,1.0));
		if (c_node.idx_max_in_score_node == edge.n_idx)
			fvals.push_back(fval(DAANCD_IND_HAS_MAX_SCORE_TO_C,1.0));

		fvals.push_back(fval(DAANCD_SCORE_TO_C,n_node.score));
		fvals.push_back(fval(DAANCD_NODE_MASS_DIFF_TO_C,mass_diff));
		fvals.push_back(fval(DAANCD_NODE_SQR_MASS_DIFF_TO_C,sqr_diff));
	}

	// NOTE(review): this n-side digest branch is identical to the c-side digest
	// branch below. It tests c_node.idx_max_in_score_node and uses n_node.score.
	// By analogy with the N_TERM branch, one might expect
	// n_node.idx_max_out_score_node == edge.c_idx and c_node.score. If both nodes
	// are NODE_DIGEST, both branches fire and the digest features are pushed
	// twice. Left unchanged because trained DAANCD models depend on these features.
	if (n_node.type == NODE_DIGEST && c_node.type != NODE_C_TERM)
	{
		fvals.push_back(fval(DAANCD_IND_CONNECTS_TO_DIGEST,1.0));
		if (c_node.idx_max_in_score_node == edge.n_idx)
			fvals.push_back(fval(DAANCD_IND_HAS_MAX_SCORE_TO_DIGEST,1.0));

		fvals.push_back(fval(DAANCD_SCORE_TO_DIGEST,n_node.score));
		fvals.push_back(fval(DAANCD_NODE_MASS_DIFF_TO_DIGEST,mass_diff));
		fvals.push_back(fval(DAANCD_NODE_SQR_MASS_DIFF_TO_DIGEST,sqr_diff));
	}

	if (c_node.type == NODE_DIGEST && n_node.type != NODE_N_TERM)
	{
		fvals.push_back(fval(DAANCD_IND_CONNECTS_TO_DIGEST,1.0));
		if (c_node.idx_max_in_score_node == edge.n_idx)
			fvals.push_back(fval(DAANCD_IND_HAS_MAX_SCORE_TO_DIGEST,1.0));

		fvals.push_back(fval(DAANCD_SCORE_TO_DIGEST,n_node.score));
		fvals.push_back(fval(DAANCD_NODE_MASS_DIFF_TO_DIGEST,mass_diff));
		fvals.push_back(fval(DAANCD_NODE_SQR_MASS_DIFF_TO_DIGEST,sqr_diff));
	}


	// features for the alternate route that consists of two amino acid edges (instead
	// of this double edge). The "best" route is the one whose middle node has the
	// highest score.

	int      num_alternate_routes=0;
	score_t best_alternate_route_score=NEG_INF;
	int      best_n_edge_idx=-1, best_c_edge_idx=-1, best_node_idx=-1;

	int i;
	for (i=0; i<n_node.out_edge_idxs.size(); i++)
	{
		const MultiEdge& n_edge = prm.get_multi_edge(n_node.out_edge_idxs[i]);
		if (n_edge.num_aa>1)
			continue;

		int j;
		for (j=0; j<c_node.in_edge_idxs.size(); j++)
		{
			const MultiEdge& c_edge = prm.get_multi_edge(c_node.in_edge_idxs[j]);
			if (c_edge.num_aa>1)
				continue;

			if (n_edge.c_idx == c_edge.n_idx)
			{
				num_alternate_routes++;
				score_t node_score = prm.get_node(n_edge.c_idx).score;
				if (node_score>best_alternate_route_score)
				{
					best_alternate_route_score = node_score;
					best_n_edge_idx = n_node.out_edge_idxs[i];
					best_c_edge_idx = c_node.in_edge_idxs[j];
					best_node_idx = c_edge.n_idx;
				}
			}
		}
	}

	fvals.push_back(fval(DAANCD_NUM_DOUBLE_EDGE_ROUTES,(float)num_alternate_routes));

	if (best_alternate_route_score>NEG_INF)
	{
		const Node& node = prm.get_node(best_node_idx);
		
		fvals.push_back(fval(DAANCD_SCORE_RANK_DOUBLE_EDGE_ROUTES,node.log_rank));
		if (node.score>0)
			fvals.push_back(fval(DAANCD_IND_IND_MAX_SCORE_ALTERNATE_MORE_THAN_ZERO,1.0));

		if (best_node_idx == n_node.idx_max_out_score_node)
			fvals.push_back(fval(DAANCD_IND_MAX_ALTERNATE_IS_MAX_OUT_N,1.0));

		if (best_node_idx == c_node.idx_max_in_score_node)
			fvals.push_back(fval(DAANCD_IND_MAX_ALTERNATE_IS_MAX_IN_C,1.0));

		// masses of the first variant's amino acid on each single-aa edge
		mass_t e1_mass = aa2mass[*(prm.get_multi_edge(best_n_edge_idx).variant_ptrs[0]+1)];
		mass_t e2_mass = aa2mass[*(prm.get_multi_edge(best_c_edge_idx).variant_ptrs[0]+1)];

		mass_t off1 = fabs(e1_mass - node.mass + n_node.mass);
		mass_t off2 = fabs(e2_mass - c_node.mass + node.mass);

		if (off1>5 || off2>5)
		{
			cout << "Error: mismatches in the offset feature of double edges!" << endl;
			exit(1);
		}

		fvals.push_back(fval(DAANCD_NODE_OFFSETS_ALTERNTE, off1+off2));
		fvals.push_back(fval(DAANCD_SQR_NODE_ODFFSETS_ALTERNATE, off1*off1 + off2*off2));

	}
	else
		fvals.push_back(fval(DAANCD_IND_NO_DOUBLE_EDGE_ROUTES,1.0));





	sort(fvals.begin(),fvals.end());
}




// Trains the DAA models, one per (n-aa, c-aa) combo from
// initialize_model_aa_combos, using every spectrum selected from fm.
// Spectra whose charge is 0 are skipped. Only double edges between two inner
// nodes (both with fragments, neither node terminal) are used.
// Samples collected:
//   - label 0 (correct), weight 3: the correct variant of a double edge that
//     leaves a correct node. For combos involving Pro/Gly/His/Ser, the
//     reversed aa order is also added as label 1 (weight 1).
//   - label 1, weight 1: other variants on double edges into/out of correct
//     nodes whose other end is not a correct node (subsampled), and about 6
//     random double edges per spectrum between non-correct nodes.
//   - The (Gap,Gap) default model's samples are subsampled with all_aa_ratio.
// Trained models are written to <resource_dir>/<model_name>_DAA.txt.
// Combos with fewer than 15 correct samples (other than the default) are skipped.
void ScoreDoubleAA::train_daa_models(const FileManager& fm, Model *model, 
									 float all_aa_ratio)
{
	FileSet fs;
	
	Config *config = model->get_config();

	fs.select_all_files(fm);

	initialize_model_aa_combos();
	const int gap_model_idx = get_daa_model_idx(Gap,Gap);
	const int num_daa_models = model_n_aas.size();

	// one dataset and one model slot for each (n-aa, c-aa) combo
	vector< ME_Regression_DataSet >  daa_datasets;

	daa_datasets.resize(num_daa_models);
	daa_me_models.resize(num_daa_models,NULL);
	
	int i;
	for (i=0; i<num_daa_models; i++)
	{
		daa_datasets[i].clear();
		daa_datasets[i].num_classes=2;
	}
	
	vector<int> num_correct_sams;

	num_correct_sams.resize(num_daa_models,0);

	int counter=0;	// NOTE(review): unused; its only use is commented out below

	while (1)
	{
		Spectrum s;
		PrmGraph prm;
		vector<int> correct_node_idxs, mirror_node_idxs, correct_double_edge_idxs;
		vector<bool> ind_correct_idxs, ind_mirror_idxs;
		Peptide pep;
		int i;

		if (! fs.get_next_spectrum(fm,config,&s) )
			break;

		correct_double_edge_idxs.clear();
		
	//	if (counter++ == 100)
	//		break;

		s.set_org_pm_with_19(s.get_true_mass_with_19());
		if (s.get_charge() ==0)
			continue;
		
		model->init_model_for_scoring_spectrum(&s);

		prm.create_graph_from_spectrum(model,&s,s.get_org_pm_with_19());

//		prm.rank_nodes_according_to_score();

//		prm.set_idxs_max_in_out_for_nodes();

		pep = s.get_peptide();
		prm.get_all_correct_node_idxs(pep ,correct_node_idxs);
		prm.get_all_mirror_node_idxs(pep, mirror_node_idxs);
		ind_correct_idxs.resize(prm.get_num_nodes(),false);
		ind_mirror_idxs.resize(prm.get_num_nodes(),false);
		
		// mark graph nodes that are correct / mirror (assumes both idx vectors
		// have the same length - TODO confirm)
		for (i=0; i<correct_node_idxs.size(); i++)
		{
			if (correct_node_idxs[i]>=0)
				ind_correct_idxs[correct_node_idxs[i]]=true;

			if (mirror_node_idxs[i]>=0)
				ind_mirror_idxs[mirror_node_idxs[i]]=true;
		}

		// collect correct edge variants and semi correct (connect to a good node)
		const vector<int> pep_aas = pep.get_amino_acids();

		int a;
		for (a=0; a<pep_aas.size()-1; a++)
		{
			int correct_aa1 = pep_aas[a];
			int correct_aa2 = pep_aas[a+1];

			// Ile and Leu are treated as the same amino acid
			if (correct_aa1 == Ile)
				correct_aa1 =  Leu;

			if (correct_aa2 == Ile)
				correct_aa2 =  Leu;

			if (correct_node_idxs[a]>=0)
			{
				const Node& n_node = prm.get_node(correct_node_idxs[a]);
			
				if (n_node.breakage.fragments.size()==0)
					continue;

				const vector<int>& out_edges = n_node.out_edge_idxs;
				
				int e;
				int num_out_double_edges=0;
				for (e=0; e<out_edges.size(); e++)
					if (prm.get_multi_edge(out_edges[e]).num_aa == 2)
						num_out_double_edges++;

				if (num_out_double_edges <= 0)
					continue;

				// more double edges => lower chance of sampling each bad one
				double thresh_for_bad_edge = 0.5 / (1.0 + num_out_double_edges);
				
				// add positive and semi negative samples for the outgoing edges
				for (e=0; e<out_edges.size(); e++)
				{
					const int me_idx = n_node.out_edge_idxs[e];
					const MultiEdge& me = prm.get_multi_edge(me_idx);
					if (me.num_aa != 2)
						continue;

					if (! is_inner_aa_edge(prm,me_idx))
						continue;

					const Node& c_node = prm.get_node(me.c_idx);
					if (c_node.breakage.fragments.size()==0)
						continue;

					int v;
					for (v=0; v<me.variant_ptrs.size(); v++)
					{
						int *v_ptr = me.variant_ptrs[v];
						int var_aa1 = *(v_ptr+1);
						int var_aa2 = *(v_ptr+2);

						if (*v_ptr == 2 && var_aa1 == correct_aa1 && var_aa2 == correct_aa2)
						{
							// add edge as a good sample
							ME_Regression_Sample sam;

							sam.label=0;
							sam.weight = 3.0;
							fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);
							
							int model_idx = get_daa_model_idx(var_aa1,var_aa2);

							if (model_idx != gap_model_idx)
							{
							// NOTE(review): with the size check below commented out, BOTH of the
							// following statements run unconditionally.
							//	if (daa_datasets[model_idx].samples.size() < MAX_AA_SAMPLES)
									daa_datasets[model_idx].add_sample(sam);
									num_correct_sams[model_idx]++;

								// add a reverse sample too if the amino acids are
								// ones that have a strong bias
								if (var_aa1 == Pro || var_aa2 == Pro ||  
									var_aa1 == Gly || var_aa2 == Gly ||
									var_aa1 == His || var_aa2 == His ||
									var_aa1 == Ser || var_aa2 == Ser   )
								{
									ME_Regression_Sample rev_sam;

									// variant layout {num_aas, aa1, aa2} with the order swapped
									int rev_ptr[3];
									rev_ptr[0]=2;     
									rev_ptr[1]=var_aa2;
									rev_ptr[2]=var_aa1;
	
									int rev_model_idx = get_daa_model_idx(var_aa2,var_aa1);

									if (daa_datasets[rev_model_idx].samples.size() < MAX_AA_SAMPLES)
									{
										rev_sam.label=1;
										rev_sam.weight = 1.0;
										fill_fval_vector_for_inner_edge(prm, me_idx, rev_ptr, rev_sam);
										daa_datasets[rev_model_idx].add_sample(rev_sam);
									}
								}
							}
							else
								// default model: always keep the first 15 positives, then subsample
								if (num_correct_sams[gap_model_idx]<15 || 
									(my_random()< all_aa_ratio && daa_datasets[gap_model_idx].samples.size() < MAX_AA_SAMPLES) )
								{
									daa_datasets[gap_model_idx].add_sample(sam);
									num_correct_sams[gap_model_idx]++;
								}

							correct_double_edge_idxs.push_back(me_idx);
							break;
						}
							// otherwise add the edge as an incorrect sample
						else if (*v_ptr == 2 && ! ind_correct_idxs[me.c_idx])
						{
							if (my_random()<thresh_for_bad_edge)
							{
									// add edge as a bad sample (label 1)
								ME_Regression_Sample sam;

								int model_idx = get_daa_model_idx(var_aa1,var_aa2);

								// NOTE(review): uses '>' while other caps use '<', allowing one extra sample
								if (daa_datasets[model_idx].samples.size() > MAX_AA_SAMPLES)
									continue;

								sam.label=1;
								sam.weight=1.0;

								fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);

								if (model_idx != gap_model_idx)
								{
									daa_datasets[model_idx].add_sample(sam);
								}
								else
									if (my_random()< all_aa_ratio && 
										daa_datasets[gap_model_idx].samples.size() < MAX_AA_SAMPLES)
										daa_datasets[gap_model_idx].add_sample(sam);
							}
						}
					}
				}

				// add negative semi correct edges for the incoming edges
				const vector<int>& in_edges = n_node.in_edge_idxs;
				int num_in_double_edges=0;
				for (e=0; e<in_edges.size(); e++)
					if (prm.get_multi_edge(in_edges[e]).num_aa == 2)
						num_in_double_edges++;

				// no incoming double edges: move on to the next peptide position
				if (num_in_double_edges <= 0)
					continue;

				thresh_for_bad_edge = 0.5 / (1.0 + num_in_double_edges);
				
				for (e=0; e<in_edges.size(); e++)
				{
					const int me_idx = in_edges[e];
					const MultiEdge& me = prm.get_multi_edge(me_idx);
					if (me.num_aa != 2)
						continue;

					if (! is_inner_aa_edge(prm,me_idx))
						continue;

					// shadows the outer n_node: this is the n-side node of the incoming edge
					const Node& n_node = prm.get_node(me.n_idx);
					if (n_node.breakage.fragments.size()==0)
						continue;

					int v;
					for (v=0; v<me.variant_ptrs.size(); v++)
					{
						int *v_ptr = me.variant_ptrs[v];
						int var_aa1 = *(v_ptr+1);
						int var_aa2 = *(v_ptr+2);
						if (*v_ptr == 2 && var_aa1 == correct_aa1 && var_aa2 == correct_aa2)
							continue;

							// otherwise add the edge as an incorrect sample
						if (*v_ptr == 2 && ! ind_correct_idxs[me.n_idx])
						{

							if (my_random()<thresh_for_bad_edge)
							{
								// add edge as a bad sample (label 1)
								ME_Regression_Sample sam;

								int model_idx = get_daa_model_idx(var_aa1,var_aa2);

								// NOTE(review): '>' here vs '<' in other caps
								if (daa_datasets[model_idx].samples.size() > MAX_AA_SAMPLES)
									continue;

								sam.label=1;
								sam.weight = 1.0;
								fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);

								if (model_idx != gap_model_idx)
								{
									daa_datasets[model_idx].add_sample(sam);
								}
								else
									if (my_random()< all_aa_ratio &&
										daa_datasets[gap_model_idx].samples.size() <MAX_AA_SAMPLES)
										daa_datasets[gap_model_idx].add_sample(sam);
							}
						}
					}
				}
			}
		}


		// add samples of random double edges whose endpoints are both non-correct
		// nodes (and which were not added as correct edges above); all their
		// variants get label 1 and weight 1

		const vector<MultiEdge>& edges = prm.get_multi_edges();
		vector<int> double_edge_idxs;
		
		for (i=0; i<edges.size(); i++)
			if (edges[i].num_aa == 2 && ! ind_correct_idxs[edges[i].n_idx] &&
				! ind_correct_idxs[edges[i].c_idx])
				double_edge_idxs.push_back(i);

		// on average 6 bad double edges will be added per spectrum
		double	rand_thresh = 6.0/(double)double_edge_idxs.size();

		for (i=0; i<double_edge_idxs.size(); i++)
		{
			bool add_anyway = false;
			int me_idx = double_edge_idxs[i];
			const MultiEdge& edge = edges[me_idx];

			if (! is_inner_aa_edge(prm,me_idx))
						continue;

			if (ind_correct_idxs[edge.n_idx] && ind_correct_idxs[edge.c_idx])
				continue;


			// just to make sure, check that this edge was not added as a correct one
			int j;
			for (j=0; j<correct_double_edge_idxs.size(); j++)
				if (correct_double_edge_idxs[j] == me_idx)
					break;

			if (j<correct_double_edge_idxs.size())
				continue;

			if ( my_random() < rand_thresh)
			{
				int v;
				for (v=0; v<edge.variant_ptrs.size(); v++)
				{
					int *v_ptr = edge.variant_ptrs[v];
					if (*v_ptr != 2)
						continue;

					int var_aa1 = *(v_ptr+1);
					int var_aa2 = *(v_ptr+2);

					int model_idx = get_daa_model_idx(var_aa1,var_aa2);

					if (daa_datasets[model_idx].samples.size() > MAX_AA_SAMPLES)
						continue;

					ME_Regression_Sample sam;
					
					sam.label=1;
					sam.weight = 1.0;
					fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);
					
				

					if (model_idx != gap_model_idx)
					{
						daa_datasets[model_idx].add_sample(sam);
					}
					else
						if (my_random()< all_aa_ratio)
							daa_datasets[gap_model_idx].add_sample(sam);
				}
			}
		}
	}
	
	string file = config->get_resource_dir() + "/" + model->get_model_name() + "_DAA.txt";
	ofstream ofs(file.c_str());

	// NOTE(review): only models skipped for having <15 positives are subtracted
	// here; models skipped by the sample-count/class-weight check below are not,
	// so this header can overstate the number of models written.
	int num_skipped_models =0;
	for (i=0; i<num_daa_models-1; i++)
		if (num_correct_sams[i]<15)
			num_skipped_models++;

	ofs << num_daa_models - num_skipped_models << endl;

	for (i=0; i<num_daa_models; i++)
	{
	

		daa_datasets[i].tally_samples();
		daa_datasets[i].num_features = DAA_NUM_FIELDS;

		// skip datasets that are too small or essentially single-class
		if (daa_datasets[i].num_samples<50 || 
			daa_datasets[i].class_weights[0]/daa_datasets[i].total_weight<0.0001 ||
			daa_datasets[i].class_weights[0]/daa_datasets[i].total_weight>0.9999)
			continue;

		cout << "DS: " << config->get_aa2label()[model_n_aas[i]] << config->get_aa2label()[model_c_aas[i]]<< endl;

		if (num_correct_sams[i]<15 &&  i<gap_model_idx)
		{
			cout << "Not enough positive samples... skipping model..." << endl;
			continue;
		}
		cout << "sample breakdown:" << endl;
	
		if (daa_datasets[i].class_weights[0]/daa_datasets[i].total_weight < 0.02)
			daa_datasets[i].calibrate_class_weights(0.02);

		daa_datasets[i].print_summary();
		daa_datasets[i].print_feature_summary();
		cout << endl;
		
		daa_me_models[i] = new ME_Regression_Model;

		// if training fails, fall back to a model that outputs the class-0 prior
		if (! daa_me_models[i]->train_cg(daa_datasets[i],400))
			daa_me_models[i]->set_weigts_for_const_prob(daa_datasets[i].class_weights[0]/daa_datasets[i].total_weight);

		// file record: "<n-aa label> <c-aa label>" header followed by the model
		ofs << config->get_aa2label()[model_n_aas[i]] << " " <<
			   config->get_aa2label()[model_c_aas[i]] << endl;
		daa_me_models[i]->write_regression_model(ofs);

		daa_me_models[i]->print_ds_probs(daa_datasets[i]);
		cout << endl;
	}
}


// Trains the single DAANCD model for double edges that are NOT inner edges
// (see is_inner_aa_edge): edges touching a terminal node, or edges where a node
// has no fragments. Spectra whose charge is 0 are skipped, and collection stops
// once the dataset exceeds MAX_AA_SAMPLES (checked once per spectrum).
// Samples, all weight 1:
//   - label 0: the correct variant on an edge whose both nodes are correct
//     (Ile converted to Leu in the peptide)
//   - label 1: every variant on an edge with at least one non-correct node
// The model is written to <resource_dir>/<model_name>_DAANCD.txt. Exits if the
// collected data is too small or essentially single-class.
void ScoreDoubleAA::train_daancd_model(const FileManager& fm, Model *model)
{
	FileSet fs;
	
	Config *config = model->get_config();

	fs.select_all_files(fm);


	// NOTE(review): these three locals are never used below
	const vector<int> session_aas = config->get_session_aas();
	const vector<int>& n_term_digest_aas = config->get_n_term_digest_aas();
	const vector<int>& c_term_digest_aas = config->get_c_term_digest_aas();



	// a single dataset / model shared by all amino-acid combos
	ME_Regression_DataSet  aa_dataset;

	daancd_me_model = NULL;
	
	aa_dataset.clear();
	aa_dataset.num_classes=2;
	
	int counter=0;	// NOTE(review): unused; its only use is commented out below
	while (1)
	{
		Spectrum s;
		PrmGraph prm;
		vector<int> correct_node_idxs,  correct_double_edge_idxs;
		vector<int> peptide_aas;
		vector<bool> ind_correct_idxs;
		Peptide pep;
		int i;

		if (! fs.get_next_spectrum(fm,config,&s) )
			break;

		correct_double_edge_idxs.clear();

		if (aa_dataset.samples.size()> MAX_AA_SAMPLES)
			break;
		
	//	if (counter++ == 1000)
	//		break;

		s.set_org_pm_with_19(s.get_true_mass_with_19());
		if (s.get_charge() ==0)
			continue;
		
		model->init_model_for_scoring_spectrum(&s);

		prm.create_graph_from_spectrum(model,&s,s.get_org_pm_with_19());

//		prm.rank_nodes_according_to_score();

//		prm.set_idxs_max_in_out_for_nodes();

		pep = s.get_peptide();
		pep.convert_IL();
		peptide_aas = pep.get_amino_acids();

		prm.get_all_correct_node_idxs(pep ,correct_node_idxs);
		ind_correct_idxs.resize(prm.get_num_nodes(),false);
		
		for (i=0; i<correct_node_idxs.size(); i++)
			if (correct_node_idxs[i]>=0)
				ind_correct_idxs[correct_node_idxs[i]]=true;

		// go over every non-inner double edge in the graph; edges between two
		// correct nodes contribute only their correct variant (label 0), all
		// other edges contribute every variant as label 1

		const vector<MultiEdge>& edges = prm.get_multi_edges();
	
		for (i=0; i<edges.size(); i++)
		{
			const MultiEdge& edge = edges[i];

			if (edge.num_aa != 2)
				continue;

			if (is_inner_aa_edge(prm,i))
				continue;

			const Node& c_node = prm.get_node(edge.c_idx);	// NOTE(review): unused
			
			int v;
			for (v=0; v<edge.variant_ptrs.size(); v++)
			{
				int *v_ptr = edge.variant_ptrs[v];

				int var_aa1 = *(v_ptr+1);
				int var_aa2 = *(v_ptr+2);
			
				if (ind_correct_idxs[edge.n_idx] && ind_correct_idxs[edge.c_idx])
				{
					// check if this is the correct aa: find the peptide position
					// of the edge's n-side node
					int j;
					for (j=0; j<correct_node_idxs.size()-1; j++)
						if (correct_node_idxs[j] == edge.n_idx)
							break;

					if (j== correct_node_idxs.size()-1)
					{
						cout << "Error: mismatch with correct node idx!" << endl;
						exit(1);
					}

					// NOTE(review): wrong variants on an edge between two correct
					// nodes are skipped here rather than added as negatives
					if (peptide_aas[j] != var_aa1 || peptide_aas[j+1] != var_aa2) // make sure this is the correct variant
						continue;

					// add edge as a good sample
					ME_Regression_Sample sam;

					sam.label=0;
					sam.weight = 1.0;
					fill_fval_vector_for_ncd_edge(prm, i, v_ptr, sam);

					aa_dataset.add_sample(sam);
				}
				else
				{
					// at least one endpoint is not a correct node: bad sample
					ME_Regression_Sample sam;

					sam.label=1;
					sam.weight = 1.0;
					fill_fval_vector_for_ncd_edge(prm, i, v_ptr, sam);

					aa_dataset.add_sample(sam);
				}
			}
		}
	}
	
	string file = config->get_resource_dir() + "/" + model->get_model_name() + "_DAANCD.txt";
	ofstream ofs(file.c_str());
	
	aa_dataset.tally_samples();
	aa_dataset.num_features = DAANCD_NUM_FIELDS;

	if (aa_dataset.num_samples<20 || 
		aa_dataset.class_weights[0]/aa_dataset.total_weight<0.00001 ||
		aa_dataset.class_weights[0]/aa_dataset.total_weight>0.99999)
	{
		cout << "Error: insufficient data samples to train DAANCD model!" << endl;
		exit(1);
	}
	
	// give the correct class at least 2.5% of the total weight
	if (aa_dataset.class_weights[0]/aa_dataset.total_weight < 0.025)
		aa_dataset.calibrate_class_weights(0.025);

	aa_dataset.print_summary();
	aa_dataset.print_feature_summary();
	cout << endl;
		
	daancd_me_model = new ME_Regression_Model;

	// if training fails, fall back to a model that outputs the class-0 prior
	if (! daancd_me_model->train_cg(aa_dataset,400))
		daancd_me_model->set_weigts_for_const_prob(aa_dataset.class_weights[0]/aa_dataset.total_weight);

	daancd_me_model->write_regression_model(ofs);

	daancd_me_model->print_ds_probs(aa_dataset);

	ofs.close();
}


// Reads the DAA models (double amino-acid edges between two inner nodes) and
// the DAANCD model (double amino-acid edges touching a terminal/digest node or
// a node without fragments).
//
// <model_name>_DAA.txt layout:
//   - a first line holding the model count;
//   - then, for each model, a "<n-aa label> <c-aa label>" header line
//     followed by the serialized regression model.
// Headers with an unknown label are skipped (their model body is still
// consumed). Combos without a model of their own are assigned the (Gap,Gap)
// default model.
//
// Returns false (with a warning) if the DAA file cannot be opened. Exits the
// process on a malformed DAA file, a missing default model, or a missing
// <model_name>_DAANCD.txt file.
bool ScoreDoubleAA::read_model(Config *config, char *model_name)
{
	string daa_path = config->get_resource_dir() + "/" + model_name + "_DAA.txt";

	fstream daa_stream;
	daa_stream.open(daa_path.c_str(),ios::in);

	if (! daa_stream.good())
	{
		cout << "Warning: couldn't open file for reading: " << daa_path << endl;
		return false;
	}

	char buff[64];
	int num_models;

	daa_stream.getline(buff,64);
	num_models = atoi(buff);
	if (num_models<1 || num_models>1000)
	{
		cout << "Error reading model file " << daa_path << endl;
		exit(1);
	}

	daa_me_models.resize(num_models,NULL);
	model_n_aas.clear();
	model_c_aas.clear();

	while (1)
	{
		daa_stream.getline(buff,64);
		if (daa_stream.eof())
			break;

		char lab1[64],lab2[64];
		if (sscanf(buff,"%s %s",lab1,lab2) != 2)
		{
			cout << "Error reading amino acid headers in DAA model!"<< endl;
			exit(1);
		}

		int aa1 = config->get_aa_from_label(string(lab1));
		int aa2 = config->get_aa_from_label(string(lab2));

		if (aa1<0 || aa2<0)
		{
			// Unknown label: the model body that follows must still be consumed,
			// otherwise its lines would be parsed as the next header.
			ME_Regression_Model skipped_model;
			skipped_model.read_regression_model(daa_stream);
			continue;
		}

		model_n_aas.push_back(aa1);
		model_c_aas.push_back(aa2);

		// get_daa_model_idx exits the process when nothing matches, so the
		// returned index is always a valid position in model_n_aas.
		const int model_idx = get_daa_model_idx(aa1,aa2);

		// The count line may not match the number of models actually stored;
		// grow the table rather than writing past its end.
		if (model_idx >= (int)daa_me_models.size())
			daa_me_models.resize(model_idx+1,NULL);

		daa_me_models[model_idx] = new ME_Regression_Model;

		daa_me_models[model_idx]->read_regression_model(daa_stream);

	}

	int gap_model_idx = get_daa_model_idx(Gap,Gap);
	if (! daa_me_models[gap_model_idx])
	{
		cout << "Error: no double edge default model!" << endl;
		exit(1);
	}

	// every combo must have a slot before falling back to the default model
	if (daa_me_models.size() < model_n_aas.size())
		daa_me_models.resize(model_n_aas.size(),NULL);

	size_t i;
	for (i=0; i<model_n_aas.size(); i++)
		if (! daa_me_models[i])
			daa_me_models[i]=daa_me_models[gap_model_idx];

	daa_stream.close();

	// read the daancd model
	string daancd_path = config->get_resource_dir() + "/" + model_name + "_DAANCD.txt";

	fstream daancd_stream;
	daancd_stream.open(daancd_path.c_str(),ios::in);

	if (! daancd_stream.good())
	{
		cout << "Error: couldn't open file for reading: " << daancd_path << endl;
		exit(1);
	}

	
	daancd_me_model = new ME_Regression_Model;

	daancd_me_model->read_regression_model(daancd_stream);

	daancd_stream.close();

	ind_model_was_read = true;

	return true;
}












