#include "ScoreSingleAA.h"
#include "RegularRankModel.h"
#include "auxfun.h"


/*
	SAA_N_SCORE,       SAA_C_SCORE,     SAA_SCORE_DIFF,
	SAA_N_NUM_FRAGS,   SAA_C_NUM_FRAGS, SAA_NUM_FRAG_DIFF,
	SAA_LOG_NC_INTEN_RATIO,      // max val +- 5 
	SAA_NODE_MASS_DIFF,     SAA_NODE_SQR_MASS_DIFF,
	SAA_BEST_PEAK_DIFF,     SAA_BEST_PEAK_SQR_DIFF,     SAA_IND_NO_PEAK_DIFF,

	SAA_IND_N_CONNECT_TO_A, SAA_N_SCORE_A, SAA_N_PEAK_DIFF_A,
	SAA_IND_N_CONNECT_TO_G, SAA_N_SCORE_G, SAA_N_PEAK_DIFF_G,
	SAA_IND_N_CONNECT_TO_D, SAA_N_SCORE_D, SAA_N_PEAK_DIFF_D,
	SAA_IND_N_CONNECT_TO_E, SAA_N_SCORE_E, SAA_N_PEAK_DIFF_E,
	SAA_IND_N_CONNECT_TO_V, SAA_N_SCORE_V, SAA_N_PEAK_DIFF_V,
	SAA_IND_N_CONNECT_TO_S, SAA_N_SCORE_S, SAA_N_PEAK_DIFF_S,
*/


float ScoreSingleAA::calc_variant_prob(const PrmGraph& prm, int me_idx, int* variant_ptr) const
{
	const MultiEdge& edge = prm.get_multi_edge(me_idx);

	if (edge.num_aa != 1)
	{
		cout << "Error: using single aa model for different num of aa!" << endl;
		exit(1);
	}

	ME_Regression_Sample sam;
	float prob = -1;

	if (is_inner_aa_edge(prm,me_idx))
	{
		fill_fval_vector_for_inner_edge(prm,me_idx,variant_ptr,sam);
		int model_idx = get_saa_model_idx(*(variant_ptr+1));
		prob=(float)saa_me_models[model_idx]->p_y_given_x(0,sam);
	}
	else
	{
		fill_fval_vector_for_ncd_edge(prm,me_idx,variant_ptr,sam);
		int model_idx = get_saancd_model_idx(*(variant_ptr+1));
		prob=(float)saancd_me_models[model_idx]->p_y_given_x(0,sam);
	}
	
	if (prob<0.01)
		prob=0.01;

	if (prob>0.99)
		prob=0.99;

	return prob;

}


bool ScoreSingleAA::is_inner_aa_edge(const PrmGraph& prm, int me_idx) const
{
	const MultiEdge& edge = prm.get_multi_edge(me_idx);
	const Node& n_node = prm.get_node(edge.n_idx);
	const Node& c_node = prm.get_node(edge.c_idx);

	if (n_node.type == NODE_N_TERM || c_node.type == NODE_C_TERM)
		return false;

	if (n_node.breakage.fragments.size()>0 && c_node.breakage.fragments.size()>0)
		return true;

	return false;
}



// Scans the SAA model list in order and returns the index of the first entry
// that is either the Gap (default) model or the model of the given aa.
// Returns -1 if no entry qualifies.
int ScoreSingleAA::get_saa_model_idx(int aa) const
{
	int idx = 0;
	while (idx < saa_model_aas.size())
	{
		const int model_aa = saa_model_aas[idx];
		if (model_aa == Gap || model_aa == aa)
			return idx;

		idx++;
	}

	return -1;
}

// Scans the SAANCD model list in order and returns the index of the first
// entry that is either the Gap (default) model or the model of the given aa.
// Returns -1 if no entry qualifies.
int ScoreSingleAA::get_saancd_model_idx(int aa) const
{
	int idx = 0;
	while (idx < saancd_model_aas.size())
	{
		const int model_aa = saancd_model_aas[idx];
		if (model_aa == Gap || model_aa == aa)
			return idx;

		idx++;
	}

	return -1;
}



/*************************************************************************

Fills the features for a single aa edge that is an internal edge
(i.e., it is not connected to the N-,C-terminals or to a digest node).

**************************************************************************/
void ScoreSingleAA::fill_fval_vector_for_inner_edge(const PrmGraph& prm, int me_idx, 
									 int* variant_ptr, ME_Regression_Sample& sam ) const
{
	const Config *config = prm.get_config();
	const vector<mass_t>& aa2mass =config->get_aa2mass();

	vector<fval>& fvals = sam.f_vals;

	if (*variant_ptr != 1)
	{
		cout << "Error: using a single aa score for a variant with " << *variant_ptr << " aas." << endl;
		exit(1);
	}

	const int var_aa = *(variant_ptr+1); // this is the variant aa (the amino acid of the edge)

	const mass_t exp_mass = aa2mass[var_aa];
	const score_t third_score = (prm.get_max_node_score()<0) ? 0 : prm.get_max_node_score() * 0.3333;
	const score_t two_thirds_score = (prm.get_max_node_score()<0) ? 0 : prm.get_max_node_score() * 0.66666;

	fvals.clear();
	const MultiEdge& me = prm.get_multi_edge(me_idx);
	const Node& n_node  = prm.get_node(me.n_idx);
	const Node& c_node  = prm.get_node(me.c_idx);

	fvals.push_back(fval(SAA_CONST,1.0));


	bool n_above_two_thirds = (n_node.score >= two_thirds_score);
	bool c_above_two_thirds = (c_node.score >= two_thirds_score);
	bool n_above_third = (n_node.score >= third_score);
	bool c_above_third = (c_node.score >= third_score);
	bool n_above_zero = (n_node.score >= 0);
	bool c_above_zero = (c_node.score >= 0);


	if (n_above_two_thirds)
		fvals.push_back(fval(SAA_N_SCORE_ABOVE_TWO_THIRDS,1));

	if (c_above_two_thirds)
		fvals.push_back(fval(SAA_C_SCORE_ABOVE_TWO_THIRDS,1));

	if (n_above_two_thirds && c_above_two_thirds)
		fvals.push_back(fval(SAA_IND_TWO_THIRDS_BOTH_ABOVE,1));

	if (! n_above_two_thirds && n_above_third)
		fvals.push_back(fval(SAA_N_SCORE_ABOVE_THIRD, 1 ));

	if (! c_above_two_thirds && c_above_third)
		fvals.push_back(fval(SAA_C_SCORE_ABOVE_THIRD, 1 ));

	if (! (n_above_two_thirds && c_above_two_thirds) && n_above_third && c_above_third)
		fvals.push_back(fval(SAA_IND_THIRD_BOTH_ABOVE, 1 ));

	if ( ! n_above_third && n_above_zero)
		fvals.push_back(fval(SAA_N_SCORE_ABOVE_ZERO, 1));

	if ( ! c_above_third && c_above_zero)
		fvals.push_back(fval(SAA_C_SCORE_ABOVE_ZERO, 1));

	if (! (n_above_third && c_above_third) && n_above_zero && c_above_zero)
		fvals.push_back(fval(SAA_IND_ZERO_BOTH_ABOVE, 1 ));

	float max_score_rank=-1,min_score_rank=-1;
	if (n_node.score > c_node.score)
	{
		max_score_rank = n_node.log_rank;
		min_score_rank = c_node.log_rank;
	}
	else
	{
		max_score_rank = c_node.log_rank;
		min_score_rank = n_node.log_rank;
	}

	fvals.push_back(fval(SAA_MAX_SCORE_RANK,max_score_rank));
	fvals.push_back(fval(SAA_MIN_SCORE_RANK,min_score_rank));

	fvals.push_back(fval(SAA_N_SCORE_RANK, n_node.log_rank)); 
	fvals.push_back(fval(SAA_C_SCORE_RANK, c_node.log_rank));
	fvals.push_back(fval(SAA_SCORE_RANK_SUM, n_node.log_rank+c_node.log_rank));
	fvals.push_back(fval(SAA_SCORE_RANK_DIFF, n_node.log_rank-c_node.log_rank));
	fvals.push_back(fval(SAA_SCORE_RANK_ABS_DIFF, fabs(n_node.log_rank-c_node.log_rank) ));

	const int num_n_frags = n_node.breakage.fragments.size();
	const int num_c_frags = c_node.breakage.fragments.size();


	fvals.push_back(fval(SAA_N_NUM_FRAGS,num_n_frags ));
	fvals.push_back(fval(SAA_C_NUM_FRAGS,num_c_frags ));
	fvals.push_back(fval(SAA_NUM_FRAG_DIFF, num_n_frags - num_c_frags ));
	fvals.push_back(fval(SAA_ABS_NUM_FRAG_DIFF, abs(num_n_frags - num_c_frags) ));

	// intensity ratio

	const intensity_t& n_inten = n_node.breakage.total_intensity;
	const intensity_t& c_inten = c_node.breakage.total_intensity;

	if (n_inten<=0 && c_inten<=0)
	{
		
	}
	else if (c_inten<=0)
	{
		fvals.push_back(fval(SAA_IND_N_STRONGER_INTEN,1));
		fvals.push_back(fval(SAA_N_STRONGER_LOG_NC_INTEN_RATIO,5.0));
	}
	else if (n_inten<=0)
	{
		fvals.push_back(fval(SAA_IND_C_STRONGER_INTEN,1));
		fvals.push_back(fval(SAA_C_STRONGER_LOG_NC_INTEN_RATIO,5.0));
	}
	else
	{
		if (n_inten>=c_inten)
		{
			fvals.push_back(fval(SAA_IND_N_STRONGER_INTEN,1));
			fvals.push_back(fval(SAA_N_STRONGER_LOG_NC_INTEN_RATIO,log(n_inten/c_inten) ));
		}
		else
		{
			fvals.push_back(fval(SAA_IND_C_STRONGER_INTEN,1));
			fvals.push_back(fval(SAA_C_STRONGER_LOG_NC_INTEN_RATIO,log(c_inten/n_inten) ));
		}
	}

	// connection features
	int max_in_c_idx  = c_node.idx_max_in_score_node;
	int max_out_n_idx = n_node.idx_max_out_score_node;

	if (max_in_c_idx<0 || max_out_n_idx<0)
	{
		cout << "Error: max_in and max_out idxs not filled correctly!" << endl;
		exit(1);
	}

	if (max_in_c_idx == me.n_idx)
	{
		fvals.push_back(fval(SAA_IND_N_IS_MAX_IDX_TO_C,1));
	}
	else
	{
		fvals.push_back(fval(SAA_IND_N_NOT_MAX_IDX_TO_C,1));
		fvals.push_back(fval(SAA_DIFF_N_MAX_IN_C_SCORE_RANKS,
			n_node.log_rank - prm.get_node(max_in_c_idx).log_rank));
	}

	if (max_out_n_idx == me.c_idx)
	{
		fvals.push_back(fval(SAA_IND_C_IS_MAX_IDX_FROM_N,1));
	}
	else
	{
		fvals.push_back(fval(SAA_IND_C_NOT_MAX_IDX_FROM_N,1));
		fvals.push_back(fval(SAA_DIFF_C_MAX_OUT_N_SCORE_RANKS,
			c_node.log_rank - prm.get_node(max_out_n_idx).log_rank));
	}

	bool both_connect_to_max =  (max_in_c_idx == me.n_idx && max_out_n_idx == me.c_idx);
	if (both_connect_to_max)
		fvals.push_back(fval(SAA_IND_BOTH_CONNECT_TO_MAX,1));

	
	// node mass diff

	mass_t node_mass_diff = fabs(c_node.mass - n_node.mass - exp_mass);

	if (node_mass_diff>3.0)
	{
		cout << "Error: large node mass diff: " << node_mass_diff << endl;
		exit(1);
	}

	fvals.push_back(fval(SAA_NODE_MASS_DIFF,node_mass_diff));
	fvals.push_back(fval(SAA_NODE_SQR_MASS_DIFF,node_mass_diff*node_mass_diff));

	// peak mass diff

	mass_t best_mass_diff = 1000.0;
	int i;
	int num_pairs=0;
	mass_t avg_diff=0;
	const vector<BreakageFragment> & n_fragments = n_node.breakage.fragments;
	const vector<BreakageFragment> & c_fragments = c_node.breakage.fragments;
	for (i=0; i<n_fragments.size(); i++)
	{
		const int& frag_type_idx = n_fragments[i].frag_type_idx;
		const int pos = c_node.breakage.get_position_of_frag_idx(frag_type_idx);

		if (pos<0)
			continue;

		num_pairs++;

		const int charge=config->get_fragment(frag_type_idx).charge;

		mass_t mass_diff = fabs(n_fragments[i].mass - c_fragments[pos].mass);
		mass_diff *= charge;
		mass_diff -= exp_mass;
		mass_diff = fabs(mass_diff);

		avg_diff+=mass_diff;

		if (mass_diff<best_mass_diff)
			best_mass_diff = mass_diff;
	}


	if (best_mass_diff<100.0)
	{
		avg_diff /= num_pairs;
		fvals.push_back(fval(SAA_NUM_FRAG_PAIRS,num_pairs));
		fvals.push_back(fval(SAA_AVG_PEAK_DIFF,avg_diff));
		fvals.push_back(fval(SAA_AVG_PEAK_SQR_DIFF,avg_diff*avg_diff));
		fvals.push_back(fval(SAA_BEST_PEAK_DIFF,best_mass_diff));
		fvals.push_back(fval(SAA_BEST_PEAK_SQR_DIFF,best_mass_diff*best_mass_diff));

		fvals.push_back(fval(SAA_AVG_PEAK_DIFF_TIMES_SCORE_RANK_SUM,
			avg_diff * (n_node.log_rank+c_node.log_rank) ));
		fvals.push_back(fval(SAA_AVG_PEAK_DIFF_TIMES_SCORE_ABS_DIFF,
			avg_diff * fabs(n_node.log_rank-c_node.log_rank) ));

		fvals.push_back(fval(SAA_AVG_PEAK_DIFF_DIV_NUM_FRAG_PAIRS,
			avg_diff / num_pairs));

		if (both_connect_to_max)
			fvals.push_back(fval(SAA_IND_BOTH_CONNECT_TO_MAX_TIMES_AVG_DIFF,avg_diff));

	}
	else
		fvals.push_back(fval(SAA_IND_NO_PEAK_DIFF,1.0));


	const int charge = prm.get_source_spectrum()->get_charge();

/*	if (charge == 1)
	{
		fvals.push_back(fval(SAA_IND_CHARGE1,1.0));
		fvals.push_back(fval(SAA_CHARGE1_MAX_SCORE_RANK,max_score_rank));
		fvals.push_back(fval(SAA_CHARGE1_MIN_SCORE_RANK,min_score_rank));
	}
	else if (charge == 2)
	{
		fvals.push_back(fval(SAA_IND_CHARGE2,1.0));
		fvals.push_back(fval(SAA_CHARGE2_MAX_SCORE_RANK,max_score_rank));
		fvals.push_back(fval(SAA_CHARGE2_MIN_SCORE_RANK,min_score_rank));
	}
	else
	{
		fvals.push_back(fval(SAA_IND_CHARGE3,1.0));
		fvals.push_back(fval(SAA_CHARGE3_MAX_SCORE_RANK,max_score_rank));
		fvals.push_back(fval(SAA_CHARGE3_MIN_SCORE_RANK,min_score_rank));
	} */


	sort(fvals.begin(),fvals.end());

//	sam.print();

}



/*
	SAANCD_IND_CONNECTS_TO_NODE_WITH_NO_FRAGS,
	SAANCD_IND_CONNECTS_TO_N_TERMINAL,
	SAANCD_IND_CONNECTS_TO_C_TERMINAL,
	SAANCD_IND_CONNECTS_C_DIGEST,
	SAANCD_IND_CONNECTS_N_DIGEST,

	SAANCD_IND_HAS_MAX_SCORE_FROM_N,  	SAANCD_SCORE_FROM_N,
	SAANCD_IND_HAS_MAX_SCORE_TO_C,      SAANCD_SCORE_TO_C,
	SAANCD_IND_HAS_MAX_SCORE_TO_DIGEST, SAANCD_SCORE_TO_DIGEST,

	SAANCD_NODE_MASS_DIFF_FROM_N,     SAANCD_NODE_SQR_MASS_DIFF_FROM_N,
	SAANCD_NODE_MASS_DIFF_TO_C,     SAANCD_NODE_SQR_MASS_DIFF_TO_C,
	SAANCD_NODE_MASS_DIFF_TO_DIGEST,     SAANCD_NODE_SQR_MASS_DIFF_TO_DIGEST,
*/
void ScoreSingleAA::fill_fval_vector_for_ncd_edge(const PrmGraph& prm, int me_idx, 
						int* variant_ptr, ME_Regression_Sample& sam) const
{
	const Config *config = prm.get_config();
	const vector<mass_t>& aa2mass =config->get_aa2mass();

	vector<fval>& fvals = sam.f_vals;

	if (*variant_ptr != 1)
	{
		cout << "Error: using a single aa score for a variant with " << *variant_ptr << " aas." << endl;
		exit(1);
	}

	const int var_aa = *(variant_ptr+1);

	const mass_t exp_mass = aa2mass[var_aa];
	const MultiEdge&  edge = prm.get_multi_edge(me_idx);
	const Node& n_node = prm.get_node(edge.n_idx);
	const Node& c_node = prm.get_node(edge.c_idx);

	mass_t mass_diff = fabs(exp_mass - c_node.mass + n_node.mass);
	mass_t sqr_diff  = mass_diff * mass_diff;

	fvals.clear();
	fvals.push_back(fval(SAANCD_CONST,1.0));

	if (n_node.breakage.fragments.size() == 0 &&
		c_node.breakage.fragments.size() == 0)
		fvals.push_back(fval(SAANCD_IND_CONNECTS_TO_NODE_WITH_NO_FRAGS,1.0));

	
	if (n_node.type == NODE_N_TERM)
	{
		fvals.push_back(fval(SAANCD_IND_CONNECTS_TO_N_TERMINAL,1.0));
		if (n_node.idx_max_out_score_node == edge.c_idx)
			fvals.push_back(fval(SAANCD_IND_HAS_MAX_SCORE_FROM_N,1.0));

		fvals.push_back(fval(SAANCD_SCORE_FROM_N,c_node.score));
		fvals.push_back(fval(SAANCD_NODE_MASS_DIFF_FROM_N,mass_diff));
		fvals.push_back(fval(SAANCD_NODE_SQR_MASS_DIFF_FROM_N,sqr_diff));
	}

	if (c_node.type == NODE_C_TERM)
	{
		fvals.push_back(fval(SAANCD_IND_CONNECTS_TO_C_TERMINAL,1.0));
		if (c_node.idx_max_in_score_node == edge.n_idx)
			fvals.push_back(fval(SAANCD_IND_HAS_MAX_SCORE_TO_C,1.0));

		fvals.push_back(fval(SAANCD_SCORE_TO_C,n_node.score));
		fvals.push_back(fval(SAANCD_NODE_MASS_DIFF_TO_C,mass_diff));
		fvals.push_back(fval(SAANCD_NODE_SQR_MASS_DIFF_TO_C,sqr_diff));
	}

	if (n_node.type == NODE_DIGEST && c_node.type != NODE_C_TERM)
	{
		fvals.push_back(fval(SAANCD_IND_CONNECTS_TO_DIGEST,1.0));
		if (c_node.idx_max_in_score_node == edge.n_idx)
			fvals.push_back(fval(SAANCD_IND_HAS_MAX_SCORE_TO_DIGEST,1.0));

		fvals.push_back(fval(SAANCD_SCORE_TO_DIGEST,n_node.score));
		fvals.push_back(fval(SAANCD_NODE_MASS_DIFF_TO_DIGEST,mass_diff));
		fvals.push_back(fval(SAANCD_NODE_SQR_MASS_DIFF_TO_DIGEST,sqr_diff));
	}

	if (c_node.type == NODE_DIGEST && n_node.type != NODE_N_TERM)
	{
		fvals.push_back(fval(SAANCD_IND_CONNECTS_TO_DIGEST,1.0));
		if (c_node.idx_max_in_score_node == edge.n_idx)
			fvals.push_back(fval(SAANCD_IND_HAS_MAX_SCORE_TO_DIGEST,1.0));

		fvals.push_back(fval(SAANCD_SCORE_TO_DIGEST,n_node.score));
		fvals.push_back(fval(SAANCD_NODE_MASS_DIFF_TO_DIGEST,mass_diff));
		fvals.push_back(fval(SAANCD_NODE_SQR_MASS_DIFF_TO_DIGEST,sqr_diff));
	}

	const int charge = prm.get_source_spectrum()->get_charge();

	if (charge == 1)
	{
//		fvals.push_back(fval(SAANCD_IND_CHARGE1,1.0));
	}
	else if (charge == 2)
	{
//		fvals.push_back(fval(SAANCD_IND_CHARGE2,1.0));
	}
	else if (charge == 3)
	{
//		fvals.push_back(fval(SAANCD_IND_CHARGE3,1.0));
	}

	sort(fvals.begin(),fvals.end());
}



// Trains the models for single aa inner edges from the spectra of the file
// manager and writes them to <resource dir>/<model name>_SAA.txt.
//
// One model is trained for every session aa plus a default model (Gap) that
// is trained on a random portion (all_aa_ratio) of the samples of all aas.
// Spectra with charge 0 are skipped, and at most 10000 spectra are read.
// Only inner edges are used (both nodes have fragments and the edge is not
// connected to the terminals, see is_inner_aa_edge).
//
// Samples with label 0 are correct edges, samples with label 1 are incorrect
// edges. Three kinds of samples are counted per model:
//   sam_counts_1 - correct edges,
//   sam_counts_2 - incorrect edges that touch a correct node,
//   sam_counts_3 - randomly selected edges that touch no correct node.
void ScoreSingleAA::train_saa_models(const FileManager& fm, Model *model,float all_aa_ratio)
{
	FileSet fs;

	fs.select_all_files(fm);

	Config *config = model->get_config();

	saa_model_aas = config->get_session_aas();
	saa_model_aas.push_back(Gap);

	const int num_saa_models = saa_model_aas.size();
	const int gap_model_idx = num_saa_models -1;

	// one set of models and one set of samples for each session aa
	vector< ME_Regression_DataSet >  aa_datasets;
	vector<int> sam_counts_1,sam_counts_2,sam_counts_3;

	aa_datasets.resize(num_saa_models);

	saa_me_models.resize(num_saa_models,NULL);
	
	sam_counts_1.resize(num_saa_models,0);
	sam_counts_2.resize(num_saa_models,0);
	sam_counts_3.resize(num_saa_models,0);

	int i;
	for (i=0; i<num_saa_models; i++)
	{
		aa_datasets[i].clear();
		aa_datasets[i].num_classes=2;
	}
	

	int counter=0;

	while (1)
	{
		Spectrum s;
		PrmGraph prm;
		vector<int> correct_node_idxs, mirror_node_idxs, correct_single_edge_idxs;
		vector<bool> ind_correct_idxs, ind_mirror_idxs;
		Peptide pep;
		int i;

		if (! fs.get_next_spectrum(fm,config,&s) )
			break;

		if (counter++ == 10000)
			break;

		correct_single_edge_idxs.clear();
		

		s.set_org_pm_with_19(s.get_true_mass_with_19());
		if (s.get_charge() ==0)
		{
			cout << "Warning: charge 0 " << s.get_file_name() << endl;
			continue;
		}
		
		model->init_model_for_scoring_spectrum(&s);

		prm.create_graph_from_spectrum(model,&s,s.get_org_pm_with_19());

		pep = s.get_peptide();
		prm.get_all_correct_node_idxs(pep ,correct_node_idxs);
		prm.get_all_mirror_node_idxs(pep, mirror_node_idxs);
		ind_correct_idxs.resize(prm.get_num_nodes(),false);
		ind_mirror_idxs.resize(prm.get_num_nodes(),false);
		
		// mark the nodes that are correct / mirrors of correct nodes
		for (i=0; i<correct_node_idxs.size(); i++)
		{
			if (correct_node_idxs[i]>=0)
				ind_correct_idxs[correct_node_idxs[i]]=true;

			// do not read past the end if the two lists differ in length
			if (i<mirror_node_idxs.size() && mirror_node_idxs[i]>=0)
				ind_mirror_idxs[mirror_node_idxs[i]]=true;
		}

		// collect correct edge variants and semi correct (connect to a good node)
		const vector<int> pep_aas = pep.get_amino_acids();

		int a;
		for (a=0; a<pep_aas.size(); a++)
		{
			int aa = pep_aas[a];
			if (aa== Ile)
				aa= Leu;

			// always >=0 here because Gap is the last entry of saa_model_aas
			const int model_idx = get_saa_model_idx(aa);

			if (correct_node_idxs[a]>=0)
			{
				const Node& n_node = prm.get_node(correct_node_idxs[a]);
			
				if (n_node.breakage.fragments.size()==0)
					continue;

				const vector<int>& out_edges = n_node.out_edge_idxs;
				
				int e;
				int num_out_single_edges=0;
				for (e=0; e<out_edges.size(); e++)
					if (prm.get_multi_edge(out_edges[e]).num_aa == 1)
						num_out_single_edges++;

				if (num_out_single_edges <= 0)
					continue;

				// probability with which an incorrect edge is taken as a sample
				double thresh_for_bad_edge = 2.0 / (1.0 + num_out_single_edges);
				
				// add positive and semi negative nodes for the outgoing edge
				for (e=0; e<out_edges.size(); e++)
				{
					const int me_idx = n_node.out_edge_idxs[e];
					const MultiEdge& me = prm.get_multi_edge(me_idx);
					if (me.num_aa != 1)
						continue;

					if (! is_inner_aa_edge(prm,me_idx))
						continue;

					const Node& c_node = prm.get_node(me.c_idx);
					if (c_node.breakage.fragments.size()==0)
						continue;

					int v;
					for (v=0; v<me.variant_ptrs.size(); v++)
					{
						int *v_ptr = me.variant_ptrs[v];
						if (*v_ptr == 1 && *(v_ptr+1) == aa)
						{
							// add edge as a good sample
							ME_Regression_Sample sam;

							sam.label=0;
							sam.weight = 3.0;
							fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);

							if (model_idx != gap_model_idx  && 
								(sam_counts_1[model_idx]<15 || 
								aa_datasets[model_idx].samples.size()< MAX_AA_SAMPLES) )
								aa_datasets[model_idx].add_sample(sam);

							if (sam_counts_1[gap_model_idx]<15 || (my_random()< all_aa_ratio &&
								aa_datasets[gap_model_idx].samples.size()< MAX_AA_SAMPLES))
							{
								aa_datasets[gap_model_idx].add_sample(sam);
								sam_counts_1[gap_model_idx]++;
							}

							correct_single_edge_idxs.push_back(me_idx);
							sam_counts_1[model_idx]++;
							break;
						}
							// otherwise add the edge as an incorrect sample
						else if (*v_ptr == 1 && ! ind_correct_idxs[me.c_idx])
						{
							if (my_random()<thresh_for_bad_edge)
							{
								// add edge as a bad sample
								// NOTE(review): the sample goes to the model of the correct
								// aa (model_idx), not of the variant's aa - confirm intended
								ME_Regression_Sample sam;

								sam.label=1;
								sam.weight=1.0;

								fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);

								if (aa_datasets[model_idx].samples.size() < MAX_AA_SAMPLES)
								{
									aa_datasets[model_idx].add_sample(sam);
									sam_counts_2[model_idx]++;
								}

								if (my_random()< all_aa_ratio && 
									aa_datasets[gap_model_idx].samples.size() < MAX_AA_SAMPLES)
									aa_datasets[gap_model_idx].add_sample(sam);
							}
						}
					}
				}

				// add negtive semi correct edges for the incoming edges
				const vector<int>& in_edges = n_node.in_edge_idxs;
				int num_in_single_edges=0;
				for (e=0; e<in_edges.size(); e++)
					if (prm.get_multi_edge(in_edges[e]).num_aa == 1)
						num_in_single_edges++;

				if (num_in_single_edges <= 0)
					continue;

				thresh_for_bad_edge = 2.0 / (1.0 + num_in_single_edges);
				
				for (e=0; e<in_edges.size(); e++)
				{
					const int me_idx = in_edges[e];
					const MultiEdge& me = prm.get_multi_edge(me_idx);
					if (me.num_aa != 1)
						continue;

					if (! is_inner_aa_edge(prm,me_idx))
						continue;

					const Node& in_n_node = prm.get_node(me.n_idx);
					if (in_n_node.breakage.fragments.size()==0)
						continue;

					int v;
					for (v=0; v<me.variant_ptrs.size(); v++)
					{
						int *v_ptr = me.variant_ptrs[v];
						if (*v_ptr == 1 && *(v_ptr+1) == aa)
							continue;

							// otherwise add the edge as an incorrect sample
						if (*v_ptr == 1 && ! ind_correct_idxs[me.n_idx])
						{
							// a flag local to this variant is used, so forcing one edge in
							// does not change the sampling threshold of the following edges
							bool add_anyway = false;

							if (ind_mirror_idxs[me.n_idx] && ind_mirror_idxs[me.c_idx])
							{
								if (aa != Pro && aa != Gly)
									continue;

								add_anyway = true; // so we add this edge for sure
								                   // should be able to tell the difference
							}

							if (add_anyway || my_random()<thresh_for_bad_edge)
							{
								// add edge as a bad sample
								ME_Regression_Sample sam;

								sam.label=1;
								sam.weight=1.0; // same weight as the other incorrect samples
								fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);

								// only add while the dataset still has room
								if (aa_datasets[model_idx].samples.size()<MAX_AA_SAMPLES)
								{
									aa_datasets[model_idx].add_sample(sam);
									sam_counts_2[model_idx]++;
								}

								if (my_random()< all_aa_ratio &&
									aa_datasets[gap_model_idx].samples.size()<MAX_AA_SAMPLES)
									aa_datasets[gap_model_idx].add_sample(sam);
							}
						}
					}
				}
			}
		}


		// add samples of random edges (make sure they are not correct ones!)
		// will add randomly selected edges upto twice the number of good edges added
		// give these samples a lower weight of 1

		const vector<MultiEdge>& edges = prm.get_multi_edges();
		vector<int> single_edge_idxs;
		
		for (i=0; i<edges.size(); i++)
			if (edges[i].num_aa == 1 && ! ind_correct_idxs[edges[i].n_idx] &&
				! ind_correct_idxs[edges[i].c_idx])
				single_edge_idxs.push_back(i);

		
		int num_good_single_edges = correct_single_edge_idxs.size();
		int num_bad_single_edges  = single_edge_idxs.size() - num_good_single_edges;

		// num_bad_single_edges is positive whenever it is used as a divisor below
		double rand_thresh = 1.0;
		if (num_bad_single_edges>num_good_single_edges)
		{
			rand_thresh = 2.0*((double)num_good_single_edges/num_bad_single_edges);

		}

		
		for (i=0; i<single_edge_idxs.size(); i++)
		{
			bool add_anyway = false;
			int me_idx = single_edge_idxs[i];
			const MultiEdge& edge = edges[me_idx];

			if (! is_inner_aa_edge(prm,me_idx))
						continue;

			if (ind_correct_idxs[edge.n_idx] && ind_correct_idxs[edge.c_idx])
				continue;

			if (ind_mirror_idxs[edge.n_idx] && ind_mirror_idxs[edge.c_idx])
			{
				int aa = *(edge.variant_ptrs[0]+1);
				if (aa != Pro && aa != Gly)
						continue;

				add_anyway = true;  // so we add this edge for sure
				                    // should be able to tell the difference with these aas
			}

			// just to make sure, check that this edge was not added as a correct one
			int j;
			for (j=0; j<correct_single_edge_idxs.size(); j++)
				if (correct_single_edge_idxs[j] == me_idx)
					break;

			if (j<correct_single_edge_idxs.size())
				continue;

			if (add_anyway || my_random() < rand_thresh)
			{
				int v;
				for (v=0; v<edge.variant_ptrs.size(); v++)
				{
					int *v_ptr = edge.variant_ptrs[v];
					if (*v_ptr != 1)
						continue;

					ME_Regression_Sample sam;
					int aa = *(v_ptr+1);

					int model_idx = get_saa_model_idx(aa);

					sam.label=1;
					sam.weight = 1.0;
					fill_fval_vector_for_inner_edge(prm, me_idx, v_ptr, sam);

					if (aa_datasets[model_idx].samples.size()< MAX_AA_SAMPLES)
					{
						aa_datasets[model_idx].add_sample(sam);
						sam_counts_3[model_idx]++;
					}

					if (my_random()< all_aa_ratio &&
						aa_datasets[gap_model_idx].samples.size()< MAX_AA_SAMPLES)
						aa_datasets[gap_model_idx].add_sample(sam);
				}
			}
		}
	}
	
	string file = config->get_resource_dir() + "/" + model->get_model_name() + "_SAA.txt";
	ofstream ofs(file.c_str());

	// the first line of the file holds the number of models that follow.
	// NOTE(review): models that are skipped below because of the dataset checks
	// (too few samples / one-sided class weights) are not subtracted from this count
	int num_skipped_models=0;

	for (i=0; i<num_saa_models-1; i++)
		if (sam_counts_1[i]<15)
			num_skipped_models++;

	ofs << num_saa_models - num_skipped_models << endl;

	for (i=0; i<num_saa_models; i++)
	{
		int model_aa = saa_model_aas[i];
		
		cout << endl << "-----------------------------------------------" << endl;
		cout << "DS: " << config->get_aa2label()[model_aa] << endl;

		// aa models with too few correct samples are skipped; the Gap model is
		// exempt since it is the default model that must be present when reading
		if (i<gap_model_idx && sam_counts_1[i]<15)
		{
			cout << "Insuffcient number of positive samples: " << sam_counts_1[i] << 
				" skipping model..." << endl;
			continue;
		}

		aa_datasets[i].tally_samples();
		aa_datasets[i].num_features = SAA_NUM_FIELDS;

		if (aa_datasets[i].num_samples<20 || 
			aa_datasets[i].class_weights[0]/aa_datasets[i].total_weight<0.00001 ||
			aa_datasets[i].class_weights[0]/aa_datasets[i].total_weight>0.99999)
			continue;

		double total_sams= sam_counts_1[i] + sam_counts_2[i] + sam_counts_3[i];

		if (total_sams>0)
		{
			cout << "sample breakdown:" << endl;
			cout << setprecision(4) << sam_counts_1[i] / total_sams << " " << 
				sam_counts_2[i] / total_sams << " " << sam_counts_3[i] / total_sams << endl;
		}

		if (aa_datasets[i].class_weights[0]/aa_datasets[i].total_weight < 0.02)
			aa_datasets[i].calibrate_class_weights(0.02);

		aa_datasets[i].print_summary();
		aa_datasets[i].print_feature_summary();
		cout << endl;
		
		saa_me_models[i] = new ME_Regression_Model;

		// if training fails, fall back to a model with a constant probability
		if (! saa_me_models[i]->train_cg(aa_datasets[i],500))
			saa_me_models[i]->set_weigts_for_const_prob(aa_datasets[i].class_weights[0]/aa_datasets[i].total_weight);

		ofs << config->get_aa2label()[model_aa] << endl;
		saa_me_models[i]->write_regression_model(ofs);
		saa_me_models[i]->print_ds_probs(aa_datasets[i]);
	}
}






// trains models of edges that connect to the terminals or digest nodes and have one
// node without fragments assigned to it.
void ScoreSingleAA::train_saancd_models(const FileManager& fm, Model *model)
{
	FileSet fs;
	
	Config *config = model->get_config();

	fs.select_all_files(fm);

	const vector<int>& n_term_digest_aas = config->get_n_term_digest_aas();
	const vector<int>& c_term_digest_aas = config->get_c_term_digest_aas();

	saancd_model_aas.clear();
	int i;
	for (i=0; i<n_term_digest_aas.size(); i++)
		saancd_model_aas.push_back(n_term_digest_aas[i]);
	for (i=0; i<c_term_digest_aas.size(); i++)
		saancd_model_aas.push_back(c_term_digest_aas[i]);

	saancd_model_aas.push_back(Gap);

	const int num_saancd_models = saancd_model_aas.size();
	const int gap_model_idx = saancd_model_aas.size()-1;
	

	// one set of models and one set of samples for each session aa
	vector< ME_Regression_DataSet >  aa_datasets;

	aa_datasets.resize(num_saancd_models);

	saancd_me_models.resize(num_saancd_models,NULL);
	

	for (i=0; i<num_saancd_models; i++)
	{
		aa_datasets[i].clear();
		aa_datasets[i].num_classes=2;
	}
	
	vector<int> num_correct_sams;
	num_correct_sams.resize(num_saancd_models,0);

	int counter=0;

	while (1)
	{
		Spectrum s;
		PrmGraph prm;
		vector<int> correct_node_idxs, mirror_node_idxs, correct_single_edge_idxs;
		vector<int> peptide_aas;
		vector<bool> ind_correct_idxs;
		Peptide pep;
		int i;

		if (! fs.get_next_spectrum(fm,config,&s) )
			break;

		correct_single_edge_idxs.clear();
		
		if (counter++ == 10000)
			break;

		s.set_org_pm_with_19(s.get_true_mass_with_19());
		if (s.get_charge() ==0)
			continue;
		
		model->init_model_for_scoring_spectrum(&s);

		prm.create_graph_from_spectrum(model,&s,s.get_org_pm_with_19());

//		prm.rank_nodes_according_to_score();

//		prm.set_idxs_max_in_out_for_nodes();

		pep = s.get_peptide();
		pep.convert_IL();
		peptide_aas = pep.get_amino_acids();
		prm.get_all_correct_node_idxs(pep ,correct_node_idxs);
		prm.get_all_mirror_node_idxs(pep, mirror_node_idxs);
		ind_correct_idxs.resize(prm.get_num_nodes(),false);
		
		for (i=0; i<correct_node_idxs.size(); i++)
			if (correct_node_idxs[i]>=0)
				ind_correct_idxs[correct_node_idxs[i]]=true;

		// add samples of random edges (make sure they are not correct ones!)
		// will add randomly selected edges upto twice the number of good edges added
		// give these samples a lower weight of 1

		const vector<MultiEdge>& edges = prm.get_multi_edges();
	
		for (i=0; i<edges.size(); i++)
		{
			const MultiEdge& edge = edges[i];

			if (edge.num_aa != 1)
				continue;

			if (is_inner_aa_edge(prm,i))
				continue;

			const Node& c_node = prm.get_node(edge.c_idx);
			
			int v;
			for (v=0; v<edge.variant_ptrs.size(); v++)
			{
				bool aa_has_model = false;
				int *v_ptr = edge.variant_ptrs[v];

				int var_aa = *(v_ptr+1);
			
				int model_idx = get_saancd_model_idx(var_aa);
				

				if (aa_datasets[model_idx].samples.size()< MAX_AA_SAMPLES &&
					ind_correct_idxs[edge.n_idx] && ind_correct_idxs[edge.c_idx])
				{
					// check if this is the correct aa
					int j;
					for (j=0; j<correct_node_idxs.size(); j++)
						if (correct_node_idxs[j] == edge.n_idx)
							break;

					if (j== correct_node_idxs.size())
					{
						cout << "Error: mismatch with correct node idx!" << endl;
						exit(1);
					}

					if (peptide_aas[j] != var_aa)  // make sure this is the correct variant
						continue;


					// add edge as a good sample
					ME_Regression_Sample sam;

					sam.label=0;
					sam.weight = 1.0;
					fill_fval_vector_for_ncd_edge(prm, i, v_ptr, sam);

					aa_datasets[model_idx].add_sample(sam);
					num_correct_sams[model_idx]++;

					if (model_idx != gap_model_idx)
					{
						aa_datasets[gap_model_idx].add_sample(sam);
						num_correct_sams[gap_model_idx]++;
					}
				}
				else
				{
					ME_Regression_Sample sam;

					sam.label=1;
					sam.weight = 1.0;
					fill_fval_vector_for_ncd_edge(prm, i, v_ptr, sam);

					aa_datasets[model_idx].add_sample(sam);
				}
			}
		}
	}
	
	string file = config->get_resource_dir() + "/" + model->get_model_name() + "_SAANCD.txt";
	ofstream ofs(file.c_str());


	int num_skipped_models=0;
	for (i=0; i<saancd_model_aas.size()-1; i++)
		if (num_correct_sams[i]<15)
			num_skipped_models++;

	ofs << num_saancd_models - num_skipped_models << endl;

	for (i=0; i<saancd_model_aas.size(); i++)
	{
		int model_aa = saancd_model_aas[i];
		cout << "DS: " << config->get_aa2label()[model_aa] << endl;

		if (num_correct_sams[i] < 15 && i<gap_model_idx)
		{
			cout << "Too few positive samples.. skipping..." << endl;
			continue;
		}

		aa_datasets[i].tally_samples();
		aa_datasets[i].num_features = SAANCD_NUM_FIELDS;

		
	
		if (aa_datasets[i].class_weights[0]/aa_datasets[i].total_weight < 0.02)
			aa_datasets[i].calibrate_class_weights(0.02);

		aa_datasets[i].print_summary();
		aa_datasets[i].print_feature_summary();
		cout << endl;
		
		saancd_me_models[i] = new ME_Regression_Model;

		if (! saancd_me_models[i]->train_cg(aa_datasets[i],400))
			saancd_me_models[i]->set_weigts_for_const_prob(aa_datasets[i].class_weights[0]/aa_datasets[i].total_weight);


		ofs << config->get_aa2label()[model_aa] << endl;
		saancd_me_models[i]->write_regression_model(ofs);
		saancd_me_models[i]->print_ds_probs(aa_datasets[i]);
	}
}

// Reads the SAA model (single amino acids in the middle of the peptide)
// and also the SAANCD (single amino acids near terminals or digest nodes missing peaks)
// from <resource dir>/<model_name>_SAA.txt and <resource dir>/<model_name>_SAANCD.txt.
//
// Each file starts with a line holding the number of models, followed by an
// aa label line and a regression model for every model. The model of the
// i-th aa that is read is stored at index i, so the model lists stay aligned
// with the aa lists used by get_saa_model_idx / get_saancd_model_idx.
// Model slots that remain empty are assigned the model for Gap (the default model).
//
// Returns false if the SAA file cannot be opened, true on success. All other
// problems (bad model count, missing SAANCD file, no Gap model) terminate
// the program.
bool ScoreSingleAA::read_model(Config *config, char *model_name)
{
	string saa_path = config->get_resource_dir() + "/" + model_name + "_SAA.txt";

	fstream saa_stream;
	saa_stream.open(saa_path.c_str(),ios::in);

	if (! saa_stream.good())
	{
		cout << "Warning: couldn't open file for reading: " << saa_path << endl;
		return false;
	}

	char buff[64];
	int num_models;

	saa_stream.getline(buff,64);
	num_models = atoi(buff);
	if (num_models<1 || num_models>1000)
	{
		cout << "Error reading model file " << saa_path << endl;
		exit(1);
	}

	saa_me_models.resize(num_models,NULL);
	saa_model_aas.clear();
	while (1)
	{ 
		saa_stream.getline(buff,64);

		// also stop on a failed read (e.g. a line that does not fit in buff),
		// otherwise the loop would never terminate
		if (saa_stream.eof() || saa_stream.fail())
			break;

		string aa_label(buff);
		int aa = config->get_aa_from_label(aa_label);	

		if (aa<0)
			continue;

		// the model is stored at the same index as its aa, lines that are
		// skipped above must not shift the position
		const int model_pos = saa_model_aas.size();
		if (model_pos >= (int)saa_me_models.size())
			saa_me_models.resize(model_pos+1,NULL);

		saa_model_aas.push_back(aa);

		saa_me_models[model_pos] = new ME_Regression_Model;

		saa_me_models[model_pos]->read_regression_model(saa_stream);
	}

	int gap_model_idx = get_saa_model_idx(Gap);
	if (gap_model_idx <0)
	{
		cout << "Error: no model for gap aa in SAA file!" << endl;
		exit(1);
	}

	int i;
	for (i=0; i<saa_me_models.size(); i++)
		if (! saa_me_models[i])
			saa_me_models[i] = saa_me_models[gap_model_idx]; // the default model for all aa

	saa_stream.close();

	// read the saancd model
	string saancd_path = config->get_resource_dir() + "/" + model_name + "_SAANCD.txt";

	fstream saancd_stream;
	saancd_stream.open(saancd_path.c_str(),ios::in);

	if (! saancd_stream.good())
	{
		cout << "Error: couldn't open file for reading: " << saancd_path << endl;
		exit(1);
	}

	saancd_stream.getline(buff,64);
	num_models = atoi(buff);
	if (num_models<1 || num_models>1000)
	{
		cout << "Error reading model file " << saancd_path << endl;
		exit(1);
	}

	saancd_me_models.resize(num_models,NULL);
	saancd_model_aas.clear();

	while (1)
	{
		saancd_stream.getline(buff,64);

		// stop on end of file and on a failed read (see above)
		if (saancd_stream.eof() || saancd_stream.fail())
			break;

		string aa_label(buff);
		int aa = config->get_aa_from_label(aa_label);	

		if (aa<0)
			continue;

		// keep the model at the same index as its aa (see above)
		const int model_pos = saancd_model_aas.size();
		if (model_pos >= (int)saancd_me_models.size())
			saancd_me_models.resize(model_pos+1,NULL);

		saancd_model_aas.push_back(aa);

		saancd_me_models[model_pos] = new ME_Regression_Model;

		saancd_me_models[model_pos]->read_regression_model(saancd_stream);

	}

	gap_model_idx = get_saancd_model_idx(Gap);
	if (gap_model_idx <0)
	{
		cout << "Error: no model for gap aa in SAANCD file!" << endl;
		exit(1);
	}

	for (i=0; i<saancd_me_models.size(); i++)
		if (! saancd_me_models[i])
			saancd_me_models[i] = saancd_me_models[gap_model_idx]; // the default model for all aa

	saancd_stream.close();

	ind_model_was_read = true;

	return true;
//	cout << "Read: " << saa_path << endl;
}


