#include "TagModel.h"
#include "DeNovoDp.h"
#include "DeNovoSolutions.h"
#include "auxfun.h"

/******************************************************************************
// Creates the feature vector for a tag.
//
// path        - the tag; positions[0..num_aa] hold its nodes, positions[i]
//               (i<num_aa) also holds the i-th amino acid and its edge data.
// prm         - graph the tag was taken from (supplies config and spectrum).
// denovo_rank - when >=0 the de novo rank features are added.
// score_rank  - rank of the tag in the score-sorted list it came from.
// top_score   - score of the first tag in that list.
// sam         - output; sam.f_vals is cleared, filled and sorted.
//
// Exits the program if the precursor mass is negative or the tag has fewer
// than 2 amino acids. The amino acid probabilities are read from
// positions[i].edge_prob; this function does not compute them.
*******************************************************************************/
void TagModel::fill_tag_vector(const SeqPath& path, const PrmGraph& prm,
							   int denovo_rank, int score_rank, score_t top_score,
							   ME_Regression_Sample& sam) const

{
	Config *config = prm.get_config();
	const mass_t tolerance = config->get_tolerance();
	const int tag_length= path.get_num_aa();
	Spectrum *spec = prm.get_source_spectrum();

	// use the corrected precursor mass when it is positive, otherwise the original one
	const float pm_with_19= (spec->get_corrected_pm_with_19() > 0 ?
							 spec->get_corrected_pm_with_19() :
							 spec->get_org_pm_with_19() );
	if (pm_with_19<0)
	{
		cout << "Error: precursor mass < 0!" << endl;
		exit(1);
	}

	// the length is validated before path.positions is indexed
	if (tag_length<2)
	{
		cout << "Tags must have length at least 2!" << endl;
		exit(1);
	}

	vector<fval>& fvals = sam.f_vals;
	const int tag_n_idx = path.positions[0].node_idx;

	fvals.clear();
	fvals.push_back(fval(SAA_CONST,1.0));

	if (denovo_rank>=0)
	{
		fvals.push_back(fval(TAGM_HAS_DENOVO_RANK,1.0));
		fvals.push_back(fval(TAGM_DENOVO_RANK,denovo_rank));
		fvals.push_back(fval(TAGM_LOG_DENOVO_RANK,log(1.0+denovo_rank)));
	}

	float log_multi_path_rank =  log(1.0+path.multi_path_rank);

	// the raw score rank is only used as a feature for the first 20 ranks
	if (score_rank<20)
		fvals.push_back(fval(TAGM_SCORE_RANK, score_rank));

	fvals.push_back(fval(TAGM_MULTI_PATH_RANK, path.multi_path_rank));
	fvals.push_back(fval(TAGM_LOG_SCORE_RANK, log(1.0+score_rank)));
	fvals.push_back(fval(TAGM_LOG_MULTI_PATH_RANK, log_multi_path_rank));

	// score features (absolute, per amino acid, and relative to the top score)
	float  avg_score_diff = (path.path_score - top_score)/tag_length;
	float  avg_score = path.path_score/tag_length;
	fvals.push_back(fval(TAGM_PATH_SCORE, path.path_score));
	fvals.push_back(fval(TAGM_AVG_PATH_SCORE, avg_score));
	fvals.push_back(fval(TAGM_DIFF_SCORE_FROM_FIRST, path.path_score - top_score));
	fvals.push_back(fval(TAGM_AVG_DIFF_SCORE_FROM_FIRST, avg_score_diff));
	fvals.push_back(fval(TAGM_SCORE_DIFF_TIMES_LOG_RANK_DIFF,
		avg_score_diff*log_multi_path_rank));

	// exactly one average-score bin is filled
	if (avg_score>=15.0)
	{
		fvals.push_back(fval(TAGM_IND_AVG_SCORE_MORE_15,avg_score));
	}
	else if (avg_score>=10.0)
	{
		fvals.push_back(fval(TAGM_IND_AVG_SCORE_MORE_10,avg_score));
	}
	else if (avg_score>=5.0)
	{
		fvals.push_back(fval(TAGM_IND_AVG_SCORE_MORE_5,avg_score));
	}
	else if (avg_score>=0)
	{
		fvals.push_back(fval(TAGM_IND_AVG_SCORE_MORE_0,avg_score));
	}
	else
		fvals.push_back(fval(TAGM_IND_AVG_SCORE_LOW,avg_score));


	// fill rank indicators (exactly one multi path bin, exactly one score rank bin)

	if (path.multi_path_rank == 0)
	{
		fvals.push_back(fval(TAGM_IND_MULTI_PATH_RANK1,1.0));
	}
	else if (path.multi_path_rank == 1)
	{
		fvals.push_back(fval(TAGM_IND_MULTI_PATH_RANK2,log(1.0+path.multi_path_rank)));
	}
	else if (path.multi_path_rank < 5)
	{
		fvals.push_back(fval(TAGM_IND_MULTI_PATH_RANK345,log(1.0+path.multi_path_rank)));
	}
	else if (path.multi_path_rank < 10)
	{
		fvals.push_back(fval(TAGM_IND_MULTI_PATH_RANK610,log(1.0+path.multi_path_rank)));
	}
	else if (path.multi_path_rank < 20)
	{
		fvals.push_back(fval(TAGM_IND_MULTI_PATH_RANK1120,log(1.0+path.multi_path_rank)));
	}
	else
		fvals.push_back(fval(TAGM_IND_MULTI_PATH_RANK_HIGHER,log(1.0+path.multi_path_rank)));


	if (score_rank == 0)
	{
		fvals.push_back(fval(TAGM_IND_SCORE_RANK1,1.0));
	}
	else if (score_rank == 1)
	{
		fvals.push_back(fval(TAGM_IND_SCORE_RANK2,1.0));
	}
	else if (score_rank < 5)
	{
		fvals.push_back(fval(TAGM_IND_SCORE_RANK345,log(1.0+score_rank)));
	}
	else if (score_rank < 10)
	{
		fvals.push_back(fval(TAGM_IND_SCORE_RANK610,log(1.0+score_rank)));
	}
	else if (score_rank < 20)
	{
		fvals.push_back(fval(TAGM_IND_SCORE_RANK1120,log(1.0+score_rank)));
	}
	else
		fvals.push_back(fval(TAGM_IND_SCORE_RANK_HIGHER,log(1.0+score_rank)));


	// collect the amino acids and edge probabilities along the tag
	int i;
	int num_skipped_nodes=0;
	int high_prob_edges=0;
	int num_edges=0;
	vector<float> aa_probs;
	vector<int>   path_aas;
	aa_probs.resize(tag_length);
	path_aas.resize(tag_length);

	for (i=0; i<tag_length; i++)
	{
		// each position contributes its own amino acid (not the first one)
		path_aas[i] = path.positions[i].aa;
		aa_probs[i] = path.positions[i].edge_prob;
		if (path.positions[i].edge_idx<0)
		{
			// a position without an edge index is counted as a skipped node
			num_skipped_nodes++;
		}
		else
		{
			num_edges++;
			if (path.positions[i].edge_prob>0.5)
				high_prob_edges++;
		}
	}

	// bin the node scores; a tag of n amino acids has n+1 positions
	int num_nodes = 0, nodes_above_10=0, nodes_above_5=0, nodes_above_0=0, nodes_below_0=0;
	for (i=0; i<=tag_length; i++)
	{
		if (path.positions[i].node_idx>=0)
		{
			num_nodes++;
			float node_score = path.positions[i].node_score;
			if (node_score>=10)
			{
				nodes_above_10++;
			}
			else if (node_score>=5)
			{
				nodes_above_5++;
			}
			else if (node_score>=0)
			{
				nodes_above_0++;
			}
			else
				nodes_below_0++;
		}
	}

	if (num_skipped_nodes>0)
		fvals.push_back(fval(TAGM_PROP_MISSING_NODES,(float)num_skipped_nodes/(float)tag_length));
	if (nodes_above_10>0)
		fvals.push_back(fval(TAGM_PROP_NODES_SCORE_ABOVE_10,(float)nodes_above_10/(float)num_nodes));
	if (nodes_above_5>0)
		fvals.push_back(fval(TAGM_PROP_NODES_SCORE_ABOVE_5,(float)nodes_above_5/(float)num_nodes));
	if (nodes_above_0>0)
		fvals.push_back(fval(TAGM_PROP_NODES_SCORE_ABOVE_0,(float)nodes_above_0/(float)num_nodes));
	if (nodes_below_0>0)
		fvals.push_back(fval(TAGM_PROP_NODES_SCORE_BELOW_0,(float)nodes_below_0/(float)num_nodes));

	// with no edges the ratio is treated as 0 (no feature is added), avoiding 0/0
	float pos_edge_ratio = (num_edges>0 ? (float)high_prob_edges/num_edges : 0);
	if (pos_edge_ratio>0)
	{
		fvals.push_back(fval(TAGM_PROP_EDGE_POSITIVE_SCORE, pos_edge_ratio));
		int num_pos_nodes = num_nodes-nodes_below_0;
		// NOTE(review): despite the feature name, the value is only the node proportion
		// (it is not multiplied by pos_edge_ratio) - kept as is
		if (num_pos_nodes>0)
			fvals.push_back(fval(TAGM_PROP_NODES_POSITIVE_TIMES_PROP_EDGES_POSITIVE, num_pos_nodes/(float)num_nodes));
	}

	bool is_near_zero = (tag_n_idx == 0);

	// calculate minimum and avg aa prob (positions with negative prob are ignored)

	float  min_log_prob=2.0;
	float  sum_logs = 0;
	int    min_aa_idx = -1;

	int num_p=0,num_above_05=0, num_below_05 =0, num_above_07 = 0, num_below_02=0;

	for (i=0; i<tag_length; i++)
	{
		float p = aa_probs[i];
		if (p<0)
			continue;

		// the small constant keeps log() finite for p==0
		float log_prob = log(p+0.0001);
		if (min_log_prob>log_prob)
		{
			min_log_prob=log_prob;
			min_aa_idx = i;
		}

		sum_logs += log_prob;

		if (p>=0.5)
		{
			num_above_05++;
		}
		else
			num_below_05++;

		if (p>=0.7)
			num_above_07++;

		if (p<0.2)
			num_below_02++;

		num_p++;
	}

	// the sum is divided by the tag length, not by the number of valid probabilities
	float log_avg_prob = sum_logs / tag_length;

	// if tag starts from mass 0 and minimum aa is the left most one
	bool inc_min_prob=false;

	if (is_near_zero &&  min_aa_idx == 0)
	{
		fvals.push_back(fval(TAGM_IND_MIN_N_TERM,1.0));
		fvals.push_back(fval(TAGM_MIN_N_LOG_MIN_PROB,min_log_prob));
		fvals.push_back(fval(TAGM_MIN_N_LOG_SUM, log_avg_prob));
		fvals.push_back(fval(TAGM_MIN_N_NUM_MISSING_NODES,num_skipped_nodes));
		inc_min_prob=true;
	}

	// if in the middle of a tag
	if (! inc_min_prob &&  min_aa_idx>0 && min_aa_idx < tag_length -1)
	{
		fvals.push_back(fval(TAGM_IND_MIN_MID,1.0));
		fvals.push_back(fval(TAGM_MIN_MID_LOG_MIN_PROB,min_log_prob));
		fvals.push_back(fval(TAGM_MIN_MID_LOG_SUM, log_avg_prob));
		fvals.push_back(fval(TAGM_MIN_MID_NUM_MISSING_NODES,num_skipped_nodes));
		inc_min_prob=true;
	}

	// on the edge of a tag, not in any of the above cases
	if (! inc_min_prob)
	{
		fvals.push_back(fval(TAGM_IND_MIN_ENDS,1.0));
		fvals.push_back(fval(TAGM_MIN_ENDS_LOG_MIN_PROB,min_log_prob));
		fvals.push_back(fval(TAGM_MIN_ENDS_LOG_SUM, log_avg_prob));
		fvals.push_back(fval(TAGM_MIN_ENDS_NUM_MISSING_NODES,num_skipped_nodes));

		inc_min_prob=true;
	}

	// fill prob prop indicators (num_p>0 whenever one of the counters is >0)
	if (num_above_05>0)
	{
		fvals.push_back(fval(TAGM_IND_HAS_PROBS_ABOVE_05,1.0));
		fvals.push_back(fval(TAGM_PROP_PROBS_ABOVE_05,(float)num_above_05/num_p));
		if (num_p == num_above_05)
			fvals.push_back(fval(TAGM_IND_ALL_PROBS_ABOVE_05,1.0));
	}
	if (num_above_07>0)
	{
		fvals.push_back(fval(TAGM_IND_HAS_PROBS_ABOVE_07,1.0));
		fvals.push_back(fval(TAGM_PROP_PROBS_ABOVE_07,(float)num_above_07/num_p));
		if (num_p == num_above_07)
			fvals.push_back(fval(TAGM_IND_ALL_PROBS_ABOVE_07,1.0));
	}
	if (num_below_05>0)
	{
		fvals.push_back(fval(TAGM_IND_HAS_PROBS_BELOW_05,1.0));
		fvals.push_back(fval(TAGM_PROP_PROBS_BELOW_05,(float)num_below_05/num_p));
	}
	if (num_below_02>0)
	{
		fvals.push_back(fval(TAGM_IND_HAS_PROBS_BELOW_02,1.0));
		fvals.push_back(fval(TAGM_PROP_PROBS_BELOW_02,(float)num_below_02/num_p));
	}


	// find mirror: look up the reversed tag sequence in the graph, using
	// (precursor mass - c_term_mass - 1.0) as the mass passed to the sub path search
	Peptide pep;
	pep.parse_from_string(config,path.seq_str);
	pep.calc_mass(config);
	pep.reverse();

	SeqPath rev=prm.get_highest_scoring_subpath(pep, pm_with_19 - path.c_term_mass -1.0);

	// a mirror is considered found only if the returned path has the full tag length
	if (rev.get_num_aa() == tag_length)
	{
		fvals.push_back(fval(TAGM_IND_HAS_MIRROR,1.0));
		fvals.push_back(fval(TAGM_MIRROR_SCORE_DIFF,(path.path_score - rev.path_score)/tag_length));
	}
	else
		fvals.push_back(fval(TAGM_IND_NO_MIRROR,1.0));

	// offset between the mass spanned by the tag's end nodes and the mass of its amino acids
	// NOTE(review): node_diff is already divided by the tolerance and is then compared
	// against multiples of the tolerance again; this looks like double scaling but is
	// kept unchanged because it defines the feature the models are trained with - confirm
	mass_t pep_mass = pep.get_mass();
	mass_t node_diff = fabs(path.c_term_mass - path.n_term_mass - pep_mass)/tolerance;
	if (node_diff<=0.1 * tolerance)
	{
		fvals.push_back(fval(TAGM_NODE_MASS_OFFSET_01_TOL,1.0));
	}
	else if (node_diff<=0.5 * tolerance)
	{
		fvals.push_back(fval(TAGM_NODE_MASS_OFFSET_05_TOL,1.0));
	}
	else if (node_diff<=tolerance)
	{
		fvals.push_back(fval(TAGM_NODE_MASS_OFFSET_10_TOL,1.0));
	}
	else
		fvals.push_back(fval(TAGM_NODE_MASS_OFFSET_LARGE,1.0));


	// terminal indicators and digest amino acids missing at the terminals

	int num_missing_digest_aa = 0;
	if (path.n_term_mass < 20)
	{
		fvals.push_back(fval(TAGM_IND_STARTS_AT_N,1.0));
		if (config->get_num_n_term_digest_aas()>0 &&
			! config->is_n_digest_aa(path_aas[0]))
			num_missing_digest_aa++;
	}

	if (path.c_term_mass>pm_with_19-25)
	{
		fvals.push_back(fval(TAGM_IND_ENDS_AT_C,1.0));
		// the last amino acid is tested against the C-terminal digest set
		if (config->get_num_c_term_digest_aas()>0 &&
			! config->is_c_digest_aa(path_aas[tag_length-1]))
			num_missing_digest_aa++;

	}

	if (num_missing_digest_aa>0)
		fvals.push_back(fval(TAGM_MISSING_DIGEST_AAS,1.0));

	// counts and minimal probabilities of Trp, Gln and Asn in the tag
	int num_w=0 , num_q = 0, num_n=0;
	float min_w = 2.0, min_q = 2.0, min_n =2.0;

	for (i=0; i<tag_length; i++)
	{
		if (aa_probs[i]<0)
			continue;

		if (path_aas[i] == Trp)
		{
			num_w++;
			if (aa_probs[i]<min_w)
				min_w = aa_probs[i];
		}
		else if (path_aas[i] == Gln)
		{
			num_q++;
			if (aa_probs[i]<min_q)
				min_q = aa_probs[i];
		}
		else if (path_aas[i] == Asn)
		{
			num_n++;
			if (aa_probs[i]<min_n)
				min_n = aa_probs[i];
		}
	}

	if (num_w>0)
	{
		fvals.push_back(fval(TAGM_NUM_W,num_w));
		fvals.push_back(fval(TAGM_MIN_PROB_W,min_w));
	}
	if (num_q>0)
	{
		fvals.push_back(fval(TAGM_NUM_Q,num_q));
		fvals.push_back(fval(TAGM_MIN_PROB_Q,min_q));
	}
	if (num_n>0)
	{
		fvals.push_back(fval(TAGM_NUM_N,num_n));
		fvals.push_back(fval(TAGM_MIN_PROB_N,min_n));
	}

	// adjacent amino acid pairs that are flagged as problematic
	int num_problematic_aas=0;
	float min_prob_problematic_aas=2.0;
	for (i=0; i<tag_length-1; i++)
	{
		if (aa_probs[i]<0)
			continue;

		if ((path_aas[i] == Ala && path_aas[i+1] == Asp) ||
			(path_aas[i] == Asp && path_aas[i+1] == Ala) ||
			(path_aas[i] == Val && path_aas[i+1] == Ser) ||
			(path_aas[i] == Ser && path_aas[i+1] == Val) ||
			(path_aas[i] == Gly && path_aas[i+1] == Glu) ||
			(path_aas[i] == Gly && path_aas[i+1] == Ala) ||
			(path_aas[i] == Ala && path_aas[i+1] == Gly) ||
			(path_aas[i] == Gly && path_aas[i+1] == Gly) )
		{
			num_problematic_aas++;
			if (aa_probs[i]<min_prob_problematic_aas)
				min_prob_problematic_aas = aa_probs[i];
		}
	}

	if (num_problematic_aas>0)
	{
		fvals.push_back(fval(TAGM_NUM_PROBLEMATIC_AAS,num_problematic_aas));
		fvals.push_back(fval(TAGM_MIN_PROB_PROBLEMATIC_AAS,min_prob_problematic_aas));
	}

	// the features were pushed in an arbitrary order, so the vector is sorted at the end
	sort(fvals.begin(),fvals.end());
}




// Trains te models for aas from the files in the mgf list
// if the spectra have no charge, assume they have the supplied charge
// the edges that are used are only ones with two good inner nodes (with frags in each,
// not connected to the terminals
void TagModel::train_models(const FileManager& fm, Model *model)
{
	Config *config = model->get_config();

	bool need_to_add_plus2_false_samples = false;
	if (config->get_pm_tolerance()<0.5)
		need_to_add_plus2_false_samples = false;

	cout << "Training tag models..." << endl;

	int max_local_tag_length = 9;  
	int max_denovo_tag_length = 9;

	int min_charge = fm.get_min_charge();
	int max_charge = fm.get_max_charge();

	
	if (local_tag_models.size()<max_charge+1)
		local_tag_models.resize(max_charge+1);

	if (denovo_tag_models.size()<max_charge+1)
		denovo_tag_models.resize(max_charge+1);

	if (charge_rank_levels.size()<max_charge+1)
		charge_rank_levels.resize(max_charge+1);

	int charge;
	int charge_with_maximum_number_of_spectra=-1;
	int max_spectra_for_a_charge = 0;
	for (charge = min_charge; charge<=max_charge; charge++)
		if (fm.get_num_spectra(charge)>max_spectra_for_a_charge)
		{
			charge_with_maximum_number_of_spectra = charge;
			max_spectra_for_a_charge=fm.get_num_spectra(charge);
		}

	
	for (charge = min_charge; charge<=max_charge; charge++)
	{
		FileSet fs;
		fs.select_files(fm,0,10000,-1,9999,charge);

		if (fs.get_total_spectra()<400 && fs.get_total_spectra()<max_spectra_for_a_charge)
			continue;

		// determine the number of models based on the number of spectra
		vector<int> rank_level_ranges;
		rank_level_ranges.clear();

		if (fs.get_total_spectra()>= 2500)
		{
			charge_rank_levels[charge].push_back(1);
			charge_rank_levels[charge].push_back(4);
			charge_rank_levels[charge].push_back(9);
			charge_rank_levels[charge].push_back(15);
			charge_rank_levels[charge].push_back(25);
			charge_rank_levels[charge].push_back(40);
			charge_rank_levels[charge].push_back(70);

		}
		else if (fs.get_total_spectra()>1500)
		{
			charge_rank_levels[charge].push_back(2);
			charge_rank_levels[charge].push_back(5);
			charge_rank_levels[charge].push_back(12);
			charge_rank_levels[charge].push_back(20);
			charge_rank_levels[charge].push_back(40);
			charge_rank_levels[charge].push_back(70);
		}
		else if (fs.get_total_spectra()>500)
		{
			charge_rank_levels[charge].push_back(3);
			charge_rank_levels[charge].push_back(10);
			charge_rank_levels[charge].push_back(25);
			charge_rank_levels[charge].push_back(50);
		}
		else
		{
			charge_rank_levels[charge].push_back(5);
			charge_rank_levels[charge].push_back(15);
			charge_rank_levels[charge].push_back(50);	
		}

		const int num_rank_levels = charge_rank_levels[charge].size()+1;

		vector< vector< ME_Regression_DataSet > > local_datasets;   // per tag length, rank level
		vector< vector< ME_Regression_DataSet >  > denovo_datasets;
		vector< vector<int> > num_corr_local_sams, num_corr_denovo_sams;


		local_datasets.resize(max_local_tag_length+1);
		denovo_datasets.resize(max_denovo_tag_length+1);
		
		num_corr_local_sams.resize(max_local_tag_length+1);
		num_corr_denovo_sams.resize(max_denovo_tag_length+1);

		local_tag_models[charge].resize(max_local_tag_length+1);
		denovo_tag_models[charge].resize(max_denovo_tag_length+1);


		int k;
		for (k=2; k<=max_local_tag_length; k++)
		{
			local_tag_models[charge][k].resize(num_rank_levels+1,NULL);
			local_datasets[k].resize(num_rank_levels+1);
			num_corr_local_sams[k].resize(num_rank_levels+1,0);

			int i;
			for (i=0; i<local_datasets[k].size(); i++)
			{
				local_datasets[k][i].clear();
				local_datasets[k][i].num_classes=2;
			}
		}

		for (k=2; k<=max_denovo_tag_length; k++)
		{
			denovo_tag_models[charge][k].resize(num_rank_levels+1,NULL);
			denovo_datasets[k].resize(num_rank_levels+1);
			num_corr_denovo_sams[k].resize(num_rank_levels+1,0);
			
			int i;
			for (i=0; i<denovo_datasets[k].size(); i++)
			{
				denovo_datasets[k][i].clear();
				denovo_datasets[k][i].num_classes=2;
			}
		}
		

		// create samples
		int counter=0;
		double rand_ratio = 2500.0/fs.get_total_spectra();

		while (1)
		{
			Spectrum s;
			PrmGraph prm, prm2;
			DeNovoDp dndp, dndp2;

			if (! fs.get_next_spectrum(fm,config,&s) )
				break;

			s.set_org_pm_with_19(s.get_true_mass_with_19());
			if (s.get_charge() ==0)
				continue;

			if (s.get_true_mass()<200)
				continue;

			if (my_random()>rand_ratio)
				continue;

			if (counter++ == 3000)
				break;


			model->init_model_for_scoring_spectrum(&s);
		
			prm.create_graph_from_spectrum(model,&s,s.get_true_mass_with_19());
			model->score_graph_edges(prm);
			dndp.fill_dp_table(&prm,25);

			if (need_to_add_plus2_false_samples)
			{
				prm2.create_graph_from_spectrum(model,&s,s.get_true_mass_with_19()+2.0);
				model->score_graph_edges(prm2);
				dndp2.fill_dp_table(&prm2,25);
			}

			Peptide pep = s.get_peptide();
			string pep_str = pep.as_string(config);

			// collect local tags	
			const int num_seq_paths = 100;

			int tag_length;
			for (tag_length=2; tag_length<=max_local_tag_length; tag_length++)
			{
				const int num_multi_paths = 100/tag_length;
				vector<MultiPath> multi_paths;
				vector<SeqPath> seq_paths;
				vector<bool> correct_multi_paths;
				int j;
			
				multi_paths.clear();	
				dndp.get_top_scoring_antisymetric_paths_with_length_limits(multi_paths, 
					num_multi_paths, tag_length , tag_length, 25);

				prm.expand_all_multi_paths(multi_paths,seq_paths, num_seq_paths);

				correct_multi_paths.resize(multi_paths.size());
				for (j=0; j<multi_paths.size(); j++)
					correct_multi_paths[j] = multi_paths[j].check_if_correct(pep,config);

				for (j=0; j<seq_paths.size(); j++)
				{
					const SeqPath& seq_path = seq_paths[j];
					if (seq_path.get_num_aa() != tag_length)
						continue;

					bool seq_correct = seq_path.check_if_correct(pep_str,config);

					if (seq_path.get_num_aa()<tag_length ||
						(correct_multi_paths[seq_path.multi_path_rank] && ! seq_correct) )
						continue;

					ME_Regression_Sample sam;
					sam.label = (seq_correct ? 0 : 1);
					sam.weight = 2.0;
					int rank_level = get_rank_level(charge,j);
					if (seq_correct)
						num_corr_local_sams[tag_length][rank_level]++;


					fill_tag_vector(seq_path,prm,seq_path.multi_path_rank,j,
											seq_paths[0].path_score,sam);

					local_datasets[tag_length][rank_level].add_sample(sam);
				}

				if (tag_length == 3)
					cout << counter << " " << tag_length << " " << 
						num_corr_local_sams[tag_length][0] << " " << local_datasets[tag_length][0].samples.size() << endl;
			}

			// collect denovo tags
			for (tag_length=2; tag_length<=max_denovo_tag_length; tag_length++)
			{
				const int num_multi_paths = 100/tag_length;
				vector<MultiPath> multi_paths;
				vector<SeqPath> seq_paths;
				vector<bool> correct_multi_paths;
				int j;
			
				multi_paths.clear();
				
				dndp.get_top_scoring_antisymetric_paths_with_length_limits(multi_paths, 
					num_multi_paths, tag_length , max_denovo_tag_length+3, 25);

				prm.expand_all_multi_paths(multi_paths,seq_paths, num_seq_paths);

				correct_multi_paths.resize(multi_paths.size());
				for (j=0; j<multi_paths.size(); j++)
					correct_multi_paths[j] = multi_paths[j].check_if_correct(pep,config);

				SeqPathHeap seq_path_heap;
				seq_path_heap.init(num_seq_paths,config->get_tolerance());

				for (j=0; j<seq_paths.size(); j++)
				{
					const SeqPath& seq_path = seq_paths[j];
					bool seq_correct = seq_path.check_if_correct(pep_str,config);

					if ( seq_path.get_num_aa()<tag_length ||
						(correct_multi_paths[seq_path.multi_path_rank] && ! seq_correct) )
						continue;
					

					// parse seq paths
					vector<SeqPath> parsed_paths;

					seq_paths[j].parse_path_to_smaller_ones(config,tag_length,
						tag_length,parsed_paths);

					if (parsed_paths.size()<=0)
						continue;

					int k;
					for (k=0; k<parsed_paths.size(); k++)
					{
						parsed_paths[k].make_seq_str(config);
						parsed_paths[k].sort_key = parsed_paths[k].path_score;
						seq_path_heap.add_path(parsed_paths[k]);
					}
				}

				sort(seq_path_heap.paths.begin(),seq_path_heap.paths.end());

				for (j=0; j<seq_path_heap.paths.size(); j++)
				{
					const SeqPath& seq_path = seq_path_heap.paths[j];
					bool seq_correct = seq_path.check_if_correct(pep_str,config);

					ME_Regression_Sample sam;
					sam.label = (seq_correct ? 0 : 1);
					sam.weight = 2.0;
					int rank_level = get_rank_level(charge,j);

					if (seq_correct)
						num_corr_denovo_sams[tag_length][rank_level]++;

					fill_tag_vector(seq_path,prm,seq_path.multi_path_rank,j,
									seq_path_heap.paths[0].path_score,sam);

					denovo_datasets[tag_length][rank_level].add_sample(sam);
				}
			}

			// perform the same as above on the prm2, and dndp2
			if (need_to_add_plus2_false_samples)
			{
				int tag_length;
				for (tag_length=2; tag_length<=max_local_tag_length; tag_length++)
				{
					const int num_multi_paths = 100/tag_length;
					vector<MultiPath> multi_paths;
					vector<SeqPath> seq_paths;
					vector<bool> correct_multi_paths;
					int j;
				
					multi_paths.clear();	
					dndp2.get_top_scoring_antisymetric_paths_with_length_limits(multi_paths, 
						num_multi_paths, tag_length , tag_length, 25);

					prm2.expand_all_multi_paths(multi_paths,seq_paths, num_seq_paths);

					correct_multi_paths.resize(multi_paths.size());
					for (j=0; j<multi_paths.size(); j++)
						correct_multi_paths[j] = multi_paths[j].check_if_correct(pep,config);

					for (j=0; j<seq_paths.size(); j++)
					{
						const SeqPath& seq_path = seq_paths[j];
						if (seq_path.get_num_aa() != tag_length)
							continue;

						bool seq_correct = seq_path.check_if_correct(pep_str,config);

						if (seq_path.get_num_aa()<tag_length ||
							(correct_multi_paths[seq_path.multi_path_rank] && ! seq_correct) )
							continue;

						ME_Regression_Sample sam;
						sam.label = (seq_correct ? 0 : 1);
						sam.weight = 1.0;
						int rank_level = get_rank_level(charge,j);
						if (seq_correct)
							num_corr_local_sams[tag_length][rank_level]++;


						fill_tag_vector(seq_path,prm2,seq_path.multi_path_rank,j,
												seq_paths[0].path_score,sam);

						local_datasets[tag_length][rank_level].add_sample(sam);
					}
				}

				// collect denovo tags
				for (tag_length=2; tag_length<=max_denovo_tag_length; tag_length++)
				{
					const int num_multi_paths = 100/tag_length;
					vector<MultiPath> multi_paths;
					vector<SeqPath> seq_paths;
					vector<bool> correct_multi_paths;
					int j;
				
					multi_paths.clear();
					
					dndp2.get_top_scoring_antisymetric_paths_with_length_limits(multi_paths, 
						num_multi_paths, tag_length , max_denovo_tag_length+3, 25);

					prm2.expand_all_multi_paths(multi_paths,seq_paths, num_seq_paths);

					correct_multi_paths.resize(multi_paths.size());
					for (j=0; j<multi_paths.size(); j++)
						correct_multi_paths[j] = multi_paths[j].check_if_correct(pep,config);

					SeqPathHeap seq_path_heap;
					seq_path_heap.init(num_seq_paths,config->get_tolerance());

					for (j=0; j<seq_paths.size(); j++)
					{
						const SeqPath& seq_path = seq_paths[j];
						bool seq_correct = seq_path.check_if_correct(pep_str,config);

						if ( seq_path.get_num_aa()<tag_length ||
							(correct_multi_paths[seq_path.multi_path_rank] && ! seq_correct) )
							continue;
						

						// parse seq paths
						vector<SeqPath> parsed_paths;

						seq_paths[j].parse_path_to_smaller_ones(config,tag_length,
							tag_length,parsed_paths);

						if (parsed_paths.size()<=0)
							continue;

						int k;
						for (k=0; k<parsed_paths.size(); k++)
						{
							parsed_paths[k].make_seq_str(config);
							parsed_paths[k].sort_key = parsed_paths[k].path_score;
							seq_path_heap.add_path(parsed_paths[k]);
						}
					}

					sort(seq_path_heap.paths.begin(),seq_path_heap.paths.end());

					for (j=0; j<seq_path_heap.paths.size(); j++)
					{
						const SeqPath& seq_path = seq_path_heap.paths[j];
						bool seq_correct = seq_path.check_if_correct(pep_str,config);

						ME_Regression_Sample sam;
						sam.label = (seq_correct ? 0 : 1);
						sam.weight = 1.0;
						int rank_level = get_rank_level(charge,j);

						if (seq_correct)
							num_corr_denovo_sams[tag_length][rank_level]++;

						fill_tag_vector(seq_path,prm2,seq_path.multi_path_rank,j,
										seq_path_heap.paths[0].path_score,sam);

						denovo_datasets[tag_length][rank_level].add_sample(sam);
					}
				}
			}

		}
	
		// Train ME models
		int length;
		for (length=2; length<local_datasets.size(); length++)
		{
			int r;

			for (r=0; r<num_rank_levels; r++)
			{
				local_datasets[length][r].tally_samples();
				local_datasets[length][r].num_features = TAGM_NUM_FIELDS;

				if (local_datasets[length][r].num_samples<20 || 
					local_datasets[length][r].class_weights[0]/local_datasets[length][r].total_weight<0.00001 ||
					local_datasets[length][r].class_weights[0]/local_datasets[length][r].total_weight>0.99999)
					continue;

				cout << endl << "-----------------------------------------------" << endl;
				cout << "DS: LOCAL: charge " << charge << "  , " << length << " r: " << r << endl;

				if (num_corr_local_sams[length][r]<8)
				{
					cout << "Insuffcient number of positive samples: " << num_corr_local_sams[length][r] << " model not trained... " << endl;
					continue;
				}
	
				if (local_datasets[length][r].class_weights[0]/local_datasets[length][r].total_weight < 0.01)
					local_datasets[length][r].calibrate_class_weights(0.01);

				local_datasets[length][r].print_summary();
				local_datasets[length][r].print_feature_summary();
				cout << endl;
				
				local_tag_models[charge][length][r] = new ME_Regression_Model;
				if (! local_tag_models[charge][length][r]->train_cg(local_datasets[length][r],500) )
					local_tag_models[charge][length][r]->set_weigts_for_const_prob(0.15);
				local_tag_models[charge][length][r]->print_ds_probs(local_datasets[length][r]);
			}
		} 

		for (length=2; length<denovo_datasets.size(); length++)
		{
			int r;
			for (r=0; r<num_rank_levels; r++)
			{
				denovo_datasets[length][r].tally_samples();
				denovo_datasets[length][r].num_features = TAGM_NUM_FIELDS;

				if (denovo_datasets[length][r].num_samples<20 || 
					denovo_datasets[length][r].class_weights[0]/denovo_datasets[length][r].total_weight<0.00001 ||
					denovo_datasets[length][r].class_weights[0]/denovo_datasets[length][r].total_weight>0.99999)
					continue;

				cout << endl << "-----------------------------------------------" << endl;
				cout << "DS: DENOVO: charge " << charge << "  , " << length << " r: " << r << endl;

				if (num_corr_denovo_sams[length][r]<8)
				{
					cout << "Insuffcient number of positive samples: " << num_corr_denovo_sams[length][r] << " model not trained... " << endl;
					continue;
				}

			
				if (denovo_datasets[length][r].class_weights[0]/denovo_datasets[length][r].total_weight < 0.01)
					denovo_datasets[length][r].calibrate_class_weights(0.01);

				denovo_datasets[length][r].print_summary();
				denovo_datasets[length][r].print_feature_summary();
				cout << endl;
				
				denovo_tag_models[charge][length][r] = new ME_Regression_Model;
				
				if (! denovo_tag_models[charge][length][r]->train_cg(denovo_datasets[length][r],500) )
					denovo_tag_models[charge][length][r]->set_weigts_for_const_prob(0.2);

				denovo_tag_models[charge][length][r]->print_ds_probs(denovo_datasets[length][r]);
			}
		} 
		

	}

	// clone models for missing charges
	for (charge=min_charge; charge<=max_charge; charge++)
	{
		if (charge_rank_levels[charge].size()==0)
		{
			if (charge<charge_with_maximum_number_of_spectra)
			{
				// look up
				int c;
				for (c=charge+1; c<=charge_with_maximum_number_of_spectra; c++)
					if (charge_rank_levels[c].size()>0)
					{
						charge_rank_levels[charge]=charge_rank_levels[c];
						denovo_tag_models[charge]=denovo_tag_models[c];
						local_tag_models[charge]=local_tag_models[c];
						break;
					}
			}
			else
			{
				// look down
				int c;
				for (c=charge-1; c>=charge_with_maximum_number_of_spectra; c--)
					if (charge_rank_levels[c].size()>0)
					{
						charge_rank_levels[charge]=charge_rank_levels[c];
						denovo_tag_models[charge]=denovo_tag_models[c];
						local_tag_models[charge]=local_tag_models[c];
						break;
					}
			}
		}
	}

	was_initialized = true;

	// write models
	write_tag_models(config);
}



// Reads the LOCAL or DENOVO tag models (selected by local_models) from the file
// <resource_dir>/<model_name>_<LOCAL|DENOVO>_TM.txt.
// File layout, as parsed below: a first line with the maximal charge, then for each
// charge a header line "charge max_model_idx num_ranks num_models", a line with the
// rank levels, and num_models records of "model_idx rank_number" each followed by
// the regression model itself.
// After reading, missing charges, tag lengths and rank levels are filled with copies
// of the pointers of existing models, and the tables are extended up to
// config->get_max_charge_for_size().
// Returns false (after printing a warning) if the file cannot be read or parsed;
// exits the program if no usable models are found.
bool TagModel::read_models_file(Config *config, bool local_models)
{
	string type_name = (local_models ? "LOCAL" : "DENOVO" );
	vector< vector< vector<ME_Regression_Model *> > >& me_models = ( local_models ?
											local_tag_models : denovo_tag_models);

	string file = config->get_resource_dir() + "/" + config->get_model_name() + "_" +
		type_name + "_TM.txt";

	ifstream ifs(file.c_str());

	// a missing/unreadable file is reported the same way as an unparsable first line
	if (! ifs.good())
	{
		cout << "Warning: couldn't read " << type_name << " tag models ..." << endl;
		return false;
	}

	int max_charge;
	char buff[64];
	buff[0]='\0';	// never parse an uninitialized buffer if getline stores nothing

	ifs.getline(buff,64);
	max_charge=atoi(buff);

	if (max_charge<=0 || max_charge>1000)
	{
		cout << "Warning: couldn't read " << type_name << " tag models ..." << endl;
		return false;
	}

	me_models.resize(max_charge+1);
	charge_rank_levels.resize(max_charge+1);

	while (1)
	{
		ifs.getline(buff,64);
		istringstream header_iss(buff);
		int charge = -1, max_model_idx =-1, num_models=-1, num_ranks=-1;

		header_iss >> charge >> max_model_idx >> num_ranks >> num_models;

		// a line that is too short to be a header marks the end of the models
		if (ifs.gcount()<3)
			break;

		// the charge is used as an index, so it must not exceed the declared maximum
		if (charge<0 || charge>max_charge || max_model_idx<0 || num_models<0 || num_ranks<0)
		{
			cout << "Warning: couldn't read model headers for " << type_name << " tag models ..." << endl;
			return false;
		}

		// read the rank levels
		charge_rank_levels[charge].clear();
		ifs.getline(buff,64);
		istringstream rss(buff);

		int i;
		for (i=0; i<num_ranks; i++)
		{
			int r=-1;
			rss >> r;
			if (r<0)
			{
				cout << "Warning couldn't read rank string for charge " << charge << " in model " << type_name << endl;
				return false;
			}
			charge_rank_levels[charge].push_back(r);
		}
		const int num_rank_levels = charge_rank_levels[charge].size()+1;

		me_models[charge].clear();
		me_models[charge].resize(max_model_idx+1);

		for (i=0; i<= max_model_idx; i++)
			me_models[charge][i].resize(num_rank_levels,NULL);


		for (i=0; i<num_models; i++)
		{
			ifs.getline(buff,64);
			istringstream model_iss(buff);

			int model_idx =-1, rank_number=-1;

			model_iss >> model_idx >> rank_number;

			if (model_idx<0 || model_idx>max_model_idx || rank_number<0 || rank_number>= num_rank_levels)
			{
				cout << "Warning: couldn't read model idx for " << type_name << " tag models " <<
					" model #" << i << "," << model_idx << " rank: " << rank_number << " ..." << endl;
				return false;
			}

			me_models[charge][model_idx][rank_number] = new ME_Regression_Model;
			me_models[charge][model_idx][rank_number]->read_regression_model(ifs);
		}
	}

	// copy missing model pointers for complete charges:
	// a lower charge that has no models takes those of the nearest higher charge that has
	int charge;
	int model_with_charge = -1;
	for (charge=1; charge<=max_charge; charge++)
	{
		if (me_models[charge].size()>0)
		{
			model_with_charge=charge;
			int c;
			for (c=charge-1; c>=1; c--)
				if (me_models[c].size()==0)
					me_models[c]=me_models[model_with_charge];
		}
	}
	if (model_with_charge<0)
	{
		cout << "Error: found no tag models with a charge for type " << type_name << " ..." <<endl;
		exit(1);
	}

	// charges above the last one with models take the models of that last charge
	for (charge=1; charge<=max_charge; charge++)
	{
		if (me_models[charge].size()==0)
			me_models[charge] = me_models[model_with_charge];
	}

	// check all tag lengths are represented
	for (charge=1; charge<=max_charge; charge++)
	{
		int tag_length;
		int good_tag_length=-1;
		for (tag_length =2; tag_length<me_models[charge].size(); tag_length++)
		{
			if (me_models[charge][tag_length].size()>0)
			{
				good_tag_length=tag_length;
				continue;
			}

			if (good_tag_length<0)
			{
				cout << "Error: missing model for " << type_name << " charge " << charge <<
					" length " << tag_length << endl;
				exit(1);
			}
			// a missing length takes the models of the closest shorter length
			me_models[charge][tag_length] = me_models[charge][good_tag_length];
		}
	}

	// check that all ranks are represented
	for (charge=1; charge<=max_charge; charge++)
	{
		int tag_length;
		for (tag_length =2; tag_length<me_models[charge].size(); tag_length++)
		{
			int r;
			int good_rank=-1;
			for (r=0; r<me_models[charge][tag_length].size(); r++)
			{
				if (me_models[charge][tag_length][r])
				{
					good_rank=r;
					// fill the empty lower ranks with this model
					int j;
					for (j=r-1; j>=0; j--)
						if (! me_models[charge][tag_length][j])
							me_models[charge][tag_length][j]=me_models[charge][tag_length][r];
				}
			}

			if (good_rank<0)
			{
				cout << "Error: missing rank models for " << type_name << " charge " << charge <<
					" length " << tag_length << endl;
				exit(1);
			}

			// the remaining empty (higher) ranks take the last model that was found
			for (r=0; r<me_models[charge][tag_length].size(); r++)
				if (! me_models[charge][tag_length][r])
					me_models[charge][tag_length][r] = me_models[charge][tag_length][good_rank];
		}
	}

	// extend me_models to reach the maximal charge of the config
	int c;
	int last_good_charge = me_models.size()-1;
	for (c=last_good_charge+1; c<=config->get_max_charge_for_size(); c++)
	{
		me_models.push_back(me_models[last_good_charge]);
		charge_rank_levels.push_back(charge_rank_levels[last_good_charge]);
	}

	return true;

}

// Writes the LOCAL or DENOVO tag models (selected by local_models) to the file
// <resource_dir>/<model_name>_<LOCAL|DENOVO>_TM.txt.
// Layout: a first line with the maximal charge that has models, then for each such
// charge a header line "charge max_model_idx num_rank_levels num_models", a line with
// the rank levels (empty if there are none), and for every non-NULL model a line
// "model_idx rank_idx" followed by the regression model itself.
// Prints a warning and writes nothing if there are no models or the file can't be opened.
void TagModel::write_models_file(Config *config, bool local_models) const
{
	string type_name = (local_models ? "LOCAL" : "DENOVO" );
	const vector< vector< vector<ME_Regression_Model *> > >& me_models = ( local_models ?
											local_tag_models : denovo_tag_models);

	// stays -1 when no charge has a model, so the check below is well defined
	int max_charge = -1;
	vector<int> max_model_idx;
	vector<int> num_models;

	max_model_idx.resize(me_models.size(),-1);
	num_models.resize(me_models.size(),0);

	// find for each charge the highest tag length index with a model and count its models
	int charge;
	for (charge=0; charge<me_models.size(); charge++)
	{
		if (me_models[charge].size()>0)
		{
			int j;
			for (j=0; j<me_models[charge].size(); j++)
			{
				int k;
				for (k=0; k<me_models[charge][j].size(); k++)
					if (me_models[charge][j][k])
					{
						max_model_idx[charge]=j;
						num_models[charge]++;
					}
			}
		}
		if (max_model_idx[charge]>=0)
			max_charge=charge;
	}

	if (max_charge<1)
	{
		cout << "Warning: No tag models found for " << type_name << " ...\n" << endl;
		return;
	}


	// write model
	string file = config->get_resource_dir() + "/" + config->get_model_name() + "_" +
		type_name + "_TM.txt";

	ofstream ofs(file.c_str());
	if (! ofs.good())
	{
		cout << "Warning: couldn't open " << file << " for writing " << type_name <<
			" tag models ..." << endl;
		return;
	}

	ofs << max_charge << endl;


	for (charge=0; charge<=max_charge; charge++)
	{
		if (max_model_idx[charge]>=0)
		{
			// a charge without an entry in charge_rank_levels is written with 0 levels
			const int num_levels = (charge < (int)charge_rank_levels.size() ?
									(int)charge_rank_levels[charge].size() : 0);

			ofs << charge << " " << max_model_idx[charge] << " " <<
				num_levels << " " << num_models[charge] << endl;

			// levels are separated by single spaces; an empty list gives an empty line
			int r;
			for (r=0; r<num_levels; r++)
			{
				if (r>0)
					ofs << " ";
				ofs << charge_rank_levels[charge][r];
			}
			ofs << endl;

			int j;
			for (j=0; j<me_models[charge].size(); j++)
			{
				int k;
				for (k=0; k<me_models[charge][j].size(); k++)
				{
					if (me_models[charge][j][k])
					{
						ofs << j << " " << k << endl;
						me_models[charge][j][k]->write_regression_model(ofs);
					}
				}
			}
		}
	}
	ofs.close();
}



float TagModel::calc_seq_prob(const SeqPath& path, int charge, const PrmGraph& prm,
			int denovo_rank, int score_rank, score_t top_score, bool from_denovo) const
{
	ME_Regression_Sample sam;
	
	sam.label=0;

	fill_tag_vector(path,prm,path.multi_path_rank,score_rank,top_score,sam);

	int tag_length = path.get_num_aa();
	if (from_denovo)
	{
		if (tag_length>= denovo_tag_models[charge].size())
			tag_length = denovo_tag_models[charge].size()-1;

		int rank_idx = this->get_rank_level(charge,score_rank);

		return denovo_tag_models[charge][tag_length][rank_idx]->p_y_given_x(0,sam);
	}
	else
	{
		if (tag_length>= local_tag_models[charge].size())
				tag_length = local_tag_models[charge].size()-1;

		int rank_idx = this->get_rank_level(charge,score_rank);

		return local_tag_models[charge][tag_length][rank_idx]->p_y_given_x(0,sam);
	}


	return 0;
}



