#include "PMCSQS.h"
#include "auxfun.h"

// Printable names of the SQS features, indexed by feature id. The table is
// passed to ME_Regression_DataSet::print_feature_summary() when the SQS models
// are trained (see train_sqs_models).
// NOTE(review): the order presumably mirrors the SQS_* feature enumeration
// declared in PMCSQS.h (which ends with SQS_NUM_FIELDS) — keep both in sync.
// NOTE(review): the last entry, "SQS_NUM_FIELDS ", carries a trailing space.
const char * SQS_var_names[]={
"SQS_CONST", "SQS_NUM_PEAKS", "SQS_SQR_NUM_PEAKS", "SQS_IND_BELOW_1000",
"SQS_IND_MZ_ABOVE_1000", "SQS_M_OVER_Z_BELOW_1000", "SQS_M_OVER_Z_ABOVE_1000",
"SQS_PROP_UPTO2G", "SQS_PROP_UPTO5G", "SQS_PROP_UPTO10G", "SQS_PROP_MORE10G",
"SQS_PROP_INTEN_UPTO2G", "SQS_PROP_INTEN_UPTO5G", "SQS_PROP_INTEN_MORE5G",
"SQS_PROP_ISO_PEAKS", "SQS_PROP_STRONG_WITH_ISO_PEAKS", "SQS_PROP_ALL_WITH_H2O_LOSS",
"SQS_PROP_ALL_WITH_NH3_LOSS", "SQS_PROP_ALL_WITH_CO_LOSS",
"SQS_PROP_STRONG_WITH_H2O_LOSS", "SQS_PROP_STRONG_WITH_NH3_LOSS",
"SQS_PROP_STRONG_WITH_CO_LOSS", "SQS_IND_MAX_TAG_LENGTH_ABOVE_4",
"SQS_IND_MAX_TAG_LENGTH_BELOW_4", "SQS_MAX_TAG_LENGTH_ABOVE_4",
"SQS_MAX_TAG_LENGTH_BELOW_4", "SQS_PROP_INTEN_IN_TAGS", "SQS_PROP_TAGS1",
"SQS_PROP_STRONG_PEAKS_IN_TAG1", "SQS_PROP_INTEN_TAG1",
"SQS_PROP_STRONG_BELOW30_TAG1", "SQS_PROP_TAGS2", "SQS_PROP_STRONG_PEAKS_IN_TAG2",
"SQS_PROP_INTEN_TAG2", "SQS_PROP_STRONG_BELOW20_TAG2", "SQS_PROP_TAGS3",
"SQS_PROP_STRONG_PEAKS_IN_TAG3", "SQS_PROP_INTEN_TAG3",
"SQS_PROP_STRONG_BELOW10_TAG3", "SQS_PEAK_DENSE_T1_I", "SQS_PEAK_DENSE_T2_I",
"SQS_PEAK_DENSE_T3_I", "SQS_INTEN_DENSE_T1_I", "SQS_INTEN_DENSE_T2_I",
"SQS_INTEN_DENSE_T3_I", "SQS_PEAK_DENSE_H1_I", "SQS_PEAK_DENSE_H2_I",
"SQS_INTEN_DENSE_H1_I", "SQS_INTEN_DENSE_H2_I", "SQS_PEAK_DENSE_T1_II",
"SQS_PEAK_DENSE_T2_II", "SQS_PEAK_DENSE_T3_II", "SQS_INTEN_DENSE_T1_II",
"SQS_INTEN_DENSE_T2_II", "SQS_INTEN_DENSE_T3_II", "SQS_PEAK_DENSE_H1_II",
"SQS_PEAK_DENSE_H2_II", "SQS_INTEN_DENSE_H1_II", "SQS_INTEN_DENSE_H2_II",
"SQS_PROP_MZ_RANGE_WITH_33_INTEN_I", "SQS_PROP_MZ_RANGE_WITH_50_INTEN_I",
"SQS_PROP_MZ_RANGE_WITH_75_INTEN_I", "SQS_PROP_MZ_RANGE_WITH_90_INTEN_I",
"SQS_PROP_MZ_RANGE_WITH_33_INTEN_II", "SQS_PROP_MZ_RANGE_WITH_50_INTEN_II",
"SQS_PROP_MZ_RANGE_WITH_75_INTEN_II", "SQS_PROP_MZ_RANGE_WITH_90_INTEN_II",
"SQS_NUM_FRAG_PAIRS_1", "SQS_NUM_STRONG_FRAG_PAIRS_1", "SQS_NUM_C2_FRAG_PAIRS_1",
"SQS_NUM_STRONG_C2_FRAG_PAIRS_1", "SQS_NUM_FRAG_PAIRS_2",
"SQS_NUM_STRONG_FRAG_PAIRS_2", "SQS_NUM_C2_FRAG_PAIRS_2",
"SQS_NUM_STRONG_C2_FRAG_PAIRS_2", "SQS_NUM_FRAG_PAIRS_3",
"SQS_NUM_STRONG_FRAG_PAIRS_3", "SQS_NUM_C2_FRAG_PAIRS_3",
"SQS_NUM_STRONG_C2_FRAG_PAIRS_3", "SQS_PROP_OF_MAX_FRAG_PAIRS_1",
"SQS_PROP_OF_MAX_STRONG_FRAG_PAIRS_1", "SQS_PROP_OF_MAX_C2_FRAG_PAIRS_1",
"SQS_PROP_OF_MAX_STRONG_C2_FRAG_PAIRS_1", "SQS_PROP_OF_MAX_FRAG_PAIRS_2",
"SQS_PROP_OF_MAX_STRONG_FRAG_PAIRS_2", "SQS_PROP_OF_MAX_C2_FRAG_PAIRS_2",
"SQS_PROP_OF_MAX_STRONG_C2_FRAG_PAIRS_2", "SQS_PROP_OF_MAX_FRAG_PAIRS_3",
"SQS_PROP_OF_MAX_STRONG_FRAG_PAIRS_3", "SQS_PROP_OF_MAX_C2_FRAG_PAIRS_3",
"SQS_PROP_OF_MAX_STRONG_C2_FRAG_PAIRS_3", "SQS_PROP_FRAG_PAIRS_1",
"SQS_PROP_STRONG_FRAG_PAIRS_1", "SQS_PROP_C2_FRAG_PAIRS_1",
"SQS_PROP_STRONG_C2_FRAG_PAIRS_1", "SQS_PROP_FRAG_PAIRS_2",
"SQS_PROP_STRONG_FRAG_PAIRS_2", "SQS_PROP_C2_FRAG_PAIRS_2",
"SQS_PROP_STRONG_C2_FRAG_PAIRS_2", "SQS_PROP_FRAG_PAIRS_3",
"SQS_PROP_STRONG_FRAG_PAIRS_3", "SQS_PROP_C2_FRAG_PAIRS_3",
"SQS_PROP_STRONG_C2_FRAG_PAIRS_3", "SQS_NUM_FIELDS " };

// Printable names of the PMC (precursor mass correction) features, indexed by
// feature id. The table is passed to ME_Regression_DataSet::print_feature_summary()
// when the PMC models are trained (see train_pmc_models).
// NOTE(review): the order presumably mirrors the PMC_* feature enumeration
// declared in PMCSQS.h (which ends with PMC_NUM_FIELDS) — keep both in sync.
const char * PMC_var_names[]={
"PMC_CONST", "PMC_DIFF_FROM_MEASURED_MZ", "PMC_ABS_DIFF_FROM_MEASURED_MZ",
"PMC_IND_IS_P0", "PMC_IND_IS_P1", "PMC_IND_IS_P2", "PMC_IND_IS_P3", 
"PMC_IND_IS_M1", "PMC_IND_IS_M2", "PMC_IND_IS_M3", "PMC_IND_IS_M4",
"PMC_IND_HAS_NO_PAIRS", "PMC_IND_HAS_NO_C2_PAIRS", "PMC_IND_HAS_NO_STRONG_PAIRS",
"PMC_IND_HAS_NO_C2_STRONG_PAIRS", "PMC_IND_HAS_PAIRS", "PMC_IND_HAS_C2_PAIRS",
"PMC_IND_HAS_STRONG_PAIRS", "PMC_IND_HAS_C2_STRONG_PAIRS", "PMC_IND_HAS_BOTH_PAIRS",
"PMC_IND_HAS_BOTH_STRONG_PAIRS", "PMC_AVG_TOL_OVER_PAIRS",
"PMC_AVG_TOL_OVER_STRONG_PAIRS", "PMC_AVG_TOL_SQR_OVER_PAIRS",
"PMC_AVG_TOL_SQR_OVER_STRONG_PAIRS", "PMC_AVG_TOL_OVER_C2_PAIRS",
"PMC_AVG_TOL_OVER_C2_STRONG_PAIRS", "PMC_AVG_TOL_SQR_OVER_C2_PAIRS",
"PMC_AVG_TOL_SQR_OVER_C2_STRONG_PAIRS", "PMC_AVG_DIFF_TOL_OVER_PAIRS",
"PMC_AVG_DIFF_TOL_OVER_STRONG_PAIRS", "PMC_AVG_DIFF_TOL_SQR_OVER_PAIRS",
"PMC_AVG_DIFF_TOL_SQR_OVER_STRONG_PAIRS", "PMC_AVG_DIFF_TOL_OVER_C2_PAIRS",
"PMC_AVG_DIFF_TOL_OVER_C2_STRONG_PAIRS", "PMC_AVG_DIFF_TOL_SQR_OVER_C2_PAIRS",
"PMC_AVG_DIFF_TOL_SQR_OVER_C2_STRONG_PAIRS", "PMC_NUM_FRAG_PAIRS",
"PMC_NUM_STRONG_FRAG_PAIRS", "PMC_NUM_C2_FRAG_PAIRS", "PMC_NUM_STRONG_C2_FRAG_PAIRS",
"PMC_IND_BEST_NUM_FRAG_PAIRS", "PMC_IND_BEST_NUM_STRONG_FRAG_PAIRS",
"PMC_IND_BEST_BOTH_PAIRS", "PMC_IND_BEST_NUM_C2_FRAG_PAIRS",
"PMC_IND_BEST_NUM_STRONG_C2_FRAG_PAIRS", "PMC_IND_BEST_BOTH_C2_PAIRS",
"PMC_IND_PAIRS_AND_MIN_TOLERANCE",		"PMC_IND_STRONG_PAIRS_AND_MIN_TOLERANCE",
"PMC_IND_C2_PAIRS_AND_MIN_TOLERANCE",	"PMC_IND_C2_STRONG_PAIRS_AND_MIN_TOLERANCE",
"PMC_LOG_DIS_PAIRS_MIN_TOL",			"PMC_LOG_DIS_STRONG_PAIRS_MIN_TOL",
"PMC_LOG_DIS_C2_PAIRS_MIN_TOL",			"PMC_LOG_DIS_C2_STRONG_PAIRS_MIN_TOL",
"PMC_PROP_NUM_PAIRS", "PMC_PROP_NUM_STRONG_PAIRS", "PMC_PROP_NUM_C2_PAIRS",
"PMC_PROP_NUM_C2_STRONG_PAIRS", "PMC_PROP_INTEN_PAIRS", "PMC_PROP_INTEN_STRONG_PAIRS",
"PMC_PROP_INTEN_C2_PAIRS", "PMC_PROP_INTEN_C2_STRONG_PAIRS", "PMC_REL_PROP_NUM_PAIRS",
"PMC_REL_PROP_NUM_STRONG_PAIRS", "PMC_REL_PROP_NUM_C2_PAIRS",
"PMC_REL_PROP_NUM_C2_STRONG_PAIRS", "PMC_REL_PROP_INTEN_PAIRS",
"PMC_REL_PROP_INTEN_STRONG_PAIRS", "PMC_REL_PROP_INTEN_C2_PAIRS",
"PMC_REL_PROP_INTEN_C2_STRONG_PAIRS", "PMC_NUM_FIELDS" };






/******************************************************************************
Trains the SQS models and writes them to <resource_dir>/<model_name>_SQS.txt.

pos_list: list file of positive spectra; they are selected per charge with
          FileSet::select_files(..., charge).
neg_list: list file of negative spectra (class 0).
inp_weights: optional per-charge class weights, indexed by charge. When it is
          non-NULL it must have at least max_charge+1 entries. When it is NULL,
          weights are derived from the spectrum counts (see below).

max_charge is set to the highest charge that has at least
MIN_SPECTRA_FOR_PMCSQS_MODEL spectra in pos_list. For every class d in
0..max_charge, up to 7000 spectra are sampled and turned into SQS feature
vectors (fill_fval_vector_with_SQS). Negatives (d == 0) get label 1 and
positives get label 0.

Models trained (ME_Regression_Model::train_cg, 500 iterations each):
  sqs_models[c][0]   - charge c (label 0) vs. negatives (label 1), c = 1..max_charge
  sqs_models[c1][c2] - charge c1 (label 0) vs. charge c2 (label 1), 1 <= c2 < c1
All other entries stay NULL.

NOTE(review): models are trained for every charge 1..max_charge, including
charges below max_charge that had fewer than MIN_SPECTRA_FOR_PMCSQS_MODEL
spectra (their sample set may be small or empty).
Calls exit(0) if fewer than two charges have enough spectra.
*******************************************************************************/
void PMCSQS_Scorer::train_sqs_models(Config *config, char *pos_list, char *neg_list,
									 float *inp_weights)
{
	vector< vector<ME_Regression_Sample> > sqs_samples; // neg, p1, p2, p3,...

	FileManager fm_pos;
	FileManager fm_neg;
	fm_pos.init_from_list_file(config,pos_list);

	const vector<int>& spectra_counts = fm_pos.get_spectra_counts();

	this->max_charge=0;

	// count the charges with enough spectra for a model; max_charge becomes the
	// highest such charge, total_counts_used the sum of their spectrum counts
	int num_classes_with_suffcient_spectra =0;
	int total_counts_used = 0;
	int charge;
	for (charge=1; charge<spectra_counts.size(); charge++)
	{
		if (spectra_counts[charge] >= MIN_SPECTRA_FOR_PMCSQS_MODEL)
		{
			num_classes_with_suffcient_spectra++;
			total_counts_used+=spectra_counts[charge];
			max_charge=charge;
		}
	}

	if (num_classes_with_suffcient_spectra<2)
	{
		cout << "Error: can't train SQS models!" << endl;
		cout << "Need at least 2 charges with " << MIN_SPECTRA_FOR_PMCSQS_MODEL << " spectra each!" << endl;
		exit(0); // NOTE(review): error path exits with status 0
	}

	this->set_frag_pair_sum_offset(1.00785); // b+y - PM+19

	this->set_bin_increment(0.1);

	sqs_samples.resize(max_charge+1);

	// cap on the number of spectra sampled per class
	const int max_to_read_per_file = 7000;

	// build the feature vectors: class 0 from neg_list, class c from the
	// charge-c spectra of pos_list
	for (charge=0; charge<=max_charge; charge++)
	{
		BasicSpecReader bsr;
		// NOTE(review): assumes read_basic_spec never returns more than 5000 peaks — confirm
		static QCPeak peaks[5000];

		FileSet fs;
		if (charge == 0)
		{
			fm_neg.init_from_list_file(config, neg_list);
			fs.select_all_files(fm_neg);
		}
		else
		{
			fs.select_files(fm_pos,0, 100000, -1, -1, charge);
		}

		const vector<SingleSpectrumFile *>& all_ssf = fs.get_ssf_pointers();
		// negatives are labeled 1, positive (charge) spectra 0
		const int sample_label = (charge == 0 ? 1 : 0);
		const int num_samples = (all_ssf.size()<max_to_read_per_file ? all_ssf.size() :
									max_to_read_per_file);
	
		sqs_samples[charge].resize(num_samples);

		// random subset of num_samples spectra when there are more than the cap,
		// otherwise all of them in order
		vector<int> ssf_idxs;
		if (num_samples<all_ssf.size())
		{
			choose_k_from_n(num_samples,all_ssf.size(),ssf_idxs);
		}
		else
		{
			int i;
			ssf_idxs.resize(all_ssf.size());
			for (i=0; i<all_ssf.size(); i++)
				ssf_idxs[i]=i;
		}
		

		int i;
		for (i=0; i<num_samples; i++)
		{
			SingleSpectrumFile* ssf = all_ssf[ssf_idxs[i]];
			BasicSpectrum bs;
		
			if (charge==0)
			{
				bs.num_peaks = bsr.read_basic_spec(config,fm_neg,ssf,peaks);
			}
			else
				bs.num_peaks = bsr.read_basic_spec(config,fm_pos,ssf,peaks);

			bs.peaks = peaks;
			bs.ssf = ssf;

			init_for_current_spec(config,bs);
			calculate_curr_spec_pmc_values(bs, bin_increment);
		
			fill_fval_vector_with_SQS(bs, sqs_samples[charge][i]);
			
			sqs_samples[charge][i].label = sample_label;

		//	cout << endl;
		}
	}

	// create SQS models
	this->sqs_models.resize(max_charge+1);
	for (charge =0; charge<=max_charge; charge++)
		sqs_models[charge].resize(max_charge+1,NULL);

	// assign class weights as follows (when inp_weights is NULL):
	// t - number of classes with enough spectra for a model
	// s_i - weight of class i spectra (its share of total_counts_used)
	// w_i = sqrt(1/t) + (t-1)/t * s_i
	// NOTE(review): an earlier version of this comment gave the first term as
	// 1/t^2, but the code below uses sqrt(1/t) — confirm which one is intended.

	vector<float> class_weights;
	class_weights.resize(max_charge+1,0.0);

	if (inp_weights)
	{
		for (charge=1; charge<=max_charge; charge++)
		{
			class_weights[charge]=inp_weights[charge];
		}
	}
	else
	{
		float t = (float)num_classes_with_suffcient_spectra;
		float one_over_t = (float)(1.0/num_classes_with_suffcient_spectra);
		for (charge=1; charge<=max_charge; charge++)
		{
			float spec_weight = (float)spectra_counts[charge]/(float)total_counts_used;
			class_weights[charge] = sqrt(one_over_t) + ((t-1.0)/t)*spec_weight;
			cout << "SQS weight for charge " << charge << " : " << class_weights[charge] << endl;
		}
	}


	// charge vs. noise models: sqs_models[charge][0]
	for (charge=1; charge<=max_charge; charge++)
	{
		ME_Regression_DataSet ds;

		ds.num_classes=2;
		ds.num_features=SQS_NUM_FIELDS;
		ds.add_samples(sqs_samples[0]);
		ds.add_samples(sqs_samples[charge]);

		// add small proportion of random samples from other charges
		// so features look for what distinguishes this from other charges too
		// not only noise
		// NOTE(review): this code is disabled; as written its loop (c=9; c<=3)
		// would never execute anyway.

	/*	int num_to_add = (int)(sqs_samples[0].size()*0.05);
		int c;
		for (c=9; c<=3; c++)
		{
			if (c==charge)
				continue;

			int sam_idx = (int)(my_random() * sqs_samples[c].size());
			ME_Regression_Sample sam = sqs_samples[c][sam_idx];

			sam.label=1;
			ds.add_sample(sam);
		} */


		ds.tally_samples();

		// the class weight passed to the data set is capped at 0.45
		float w = class_weights[charge];
		if (w>0.45)
			w=0.45;
		ds.calibrate_class_weights(w);

		cout << endl << "CHARGE " << charge << endl;

		ds.print_feature_summary(cout,SQS_var_names);

		sqs_models[charge][0]=new ME_Regression_Model;

		sqs_models[charge][0]->train_cg(ds,500);

		sqs_models[charge][0]->print_ds_probs(ds);

		// boot strap - don't use it
		//
	/*	int r;
		for (r=0; r<1; r++)
		{
			int total_pruned=0;
			ds.samples.clear();

			int i;
			for (i=0; i<sqs_samples[0].size(); i++)
			{
				float prob=sqs_models[charge]->p_y_given_x(0,sqs_samples[0][i]);
				if (prob>0.75)
				{
					total_pruned++;
				}
				else
					ds.add_sample(sqs_samples[0][i]);
			}

			cout << "Pruned " << total_pruned << endl;
		

			ds.add_samples(sqs_samples[charge]);

			ds.tally_samples();

			ds.calibrate_class_weights(class_weights[charge]);

			sqs_models[charge]->train_cg(ds,500);

			sqs_models[charge]->print_ds_probs(ds);
		} */

	}


	
	////////////////////////////////////////////
	// train model vs. model if charge1>charge2
	// charge1 samples keep label 0, charge2 samples are added with label 1
	if (1)
	{
		int charge1,charge2;
		for (charge1=2; charge1<=max_charge; charge1++)
		{
			for (charge2=1; charge2<charge1; charge2++)
			{

				ME_Regression_DataSet ds;

				ds.num_classes=2;
				ds.num_features=SQS_NUM_FIELDS;


				ds.add_samples(sqs_samples[charge1]);

				// relabel temporarily: the stored charge2 samples are restored to label 0
				int i;
				for (i=0; i<sqs_samples[charge2].size(); i++)
				{
					sqs_samples[charge2][i].label=1;
					ds.add_sample(sqs_samples[charge2][i]);
					sqs_samples[charge2][i].label=0;
				}

				ds.tally_samples();

				float relative_weight = class_weights[charge1]/(class_weights[charge1]+class_weights[charge2]);
				ds.calibrate_class_weights(relative_weight);

				sqs_models[charge1][charge2] = new ME_Regression_Model;

				cout << endl << "CHARGE " << charge1 << " vs " << charge2 << endl;
				ds.print_feature_summary(cout,SQS_var_names);

				sqs_models[charge1][charge2]->train_cg(ds,500);
				sqs_models[charge1][charge2]->print_ds_probs(ds);
			}
		}
	}


	////////////////////////////////////////////
	// final report on datasets
	// For each class d, prints the fraction of its samples whose most probable
	// charge-vs-noise model (probability of class 0 above p_thresh) is charge c;
	// column 0 counts samples where no model exceeded p_thresh.
	// NOTE(review): divides by sqs_samples[d].size(), which may be 0 for a
	// charge with no selected spectra.
	cout << endl;

	float p_thresh = 0.1;
	int d;
	for (d=0; d<=max_charge; d++)
	{
		vector<int> counts;
		vector<int> max_counts;
		counts.resize(max_charge+1,0);
		max_counts.resize(max_charge+1,0);

		int i;
		for (i=0; i<sqs_samples[d].size(); i++)
		{
			bool above_thresh=false;
			float max_prob=0;
			int   max_class=0;
			int c;
			for (c=1; c<=max_charge; c++)
			{
				if (! sqs_models[c][0])
					continue;

				float prob = sqs_models[c][0]->p_y_given_x(0,sqs_samples[d][i]);
				if (prob>p_thresh)
				{
					counts[c]++;
					above_thresh=true;
					if (prob>max_prob)
					{
						max_prob=prob;
						max_class=c;
					}
				}
			}
			max_counts[max_class]++;

			if (! above_thresh)
				counts[0]++;
		}

		cout << d << "\t";
		for (i=0; i<=max_charge; i++)
			cout << fixed << setprecision(4) << max_counts[i]/(float)sqs_samples[d].size() << "\t";
		cout << endl;
	}



	ind_initialized_sqs = true;

	string path;
	path = config->get_resource_dir() + "/" + config->get_model_name() + "_SQS.txt";
	write_sqs_models(path.c_str());
}


void PMCSQS_Scorer::write_sqs_models(const char *path) const
{
	ofstream out_stream(path,ios::out);
	if (! out_stream.good())
	{
		cout << "Error: couldn't open pmc model for writing: " << path << endl;
		exit(1);
	}


	out_stream << sqs_models.size() << endl;

	// write ME models
	int i;
	for (i=0; i<sqs_models.size(); i++)
	{
		int j;
		for (j=0; j<sqs_models[i].size(); j++)
		{
			if (sqs_models[i][j])
			{
				out_stream << i << " " << j << endl;
				sqs_models[i][j]->write_regression_model(out_stream);
			}
		}
	}	
	out_stream.close();
}


/******************************************************************************
Reads SQS models written by write_sqs_models() from
<resource_dir>/<file> and sets ind_initialized_sqs.

The first line holds the number of charge slots (max_charge+1). If max_charge
was already set (e.g. by read_pmc_models), the two values must agree. Each
following non-blank line "charge1 charge2" is followed by one serialized
ME_Regression_Model, which is stored in sqs_models[charge1][charge2].
Returns false if the file cannot be opened. Exits on a malformed file.
*******************************************************************************/
bool PMCSQS_Scorer::read_sqs_models(Config *_config, char *file)
{
	config = _config;

	string path;
	path = config->get_resource_dir() + "/" + string(file);


	ifstream in_stream(path.c_str(),ios::in);
	if (! in_stream.good())
	{
		cout << "Warning: couldn't open SQS model for reading: " << path << endl;
		return false;
	}


	char buff[256];
	int num_charges=-1;

	in_stream.getline(buff,sizeof(buff));
	istringstream iss(buff);

	iss >> num_charges;

	// a missing or unparsable count would otherwise reach resize(-1)
	if (num_charges <= 0)
	{
		cout << "Error: reading SQS, bad number of charges in first line: " << endl << buff << endl;
		exit(1);
	}

	if (max_charge <=0)
	{
		max_charge=num_charges-1;
	}
	else
	{
		if (max_charge != num_charges-1)
		{
			cout << "Error: max_charge is not consistent between PMC and SQS!" << endl;
			exit(1);
		}
	}
	
	int i;
	sqs_models.resize(num_charges);
	for (i=0; i<num_charges; i++)
		sqs_models[i].resize(num_charges,NULL);

	
	// read ME models
	while (in_stream.getline(buff,sizeof(buff)))
	{
		int charge1=-1,charge2=-1;

		// sscanf returns EOF for empty / whitespace-only lines: tolerate them
		if (sscanf(buff,"%d %d",&charge1,&charge2) == EOF)
			continue;

		if (charge1<1 || charge2<0 || charge1>max_charge || charge2>=charge1)
		{
			cout << "Error: reading SQS, bad charge numbers in line: " << endl << buff << endl;
			exit(1);
		}

		sqs_models[charge1][charge2] = new ME_Regression_Model;
		sqs_models[charge1][charge2]->read_regression_model(in_stream);
	}

	in_stream.close();
	this->ind_initialized_sqs = true;
	return true;
}





/******************************************************************************
Writes the PMC models to the text file at path. This is the format read by
read_pmc_models().

Line 1: bin_increment frag_pair_sum_offset
Line 2: pmc_models.size() followed by every charge_mz_biases entry
        (3 significant digits)
Then, for every non-NULL pmc_models[i], a line "i" followed by the model's
serialization (ME_Regression_Model::write_regression_model).
Exits the program if the file cannot be opened.
*******************************************************************************/
void PMCSQS_Scorer::write_pmc_models(const char *path) const
{
	ofstream out_stream(path,ios::out);
	if (! out_stream.good())
	{
		cout << "Error: couldn't open pmc model for writing: " << path << endl;
		exit(1);
	}

	out_stream << this->bin_increment << " " << this->frag_pair_sum_offset << endl;
	out_stream << pmc_models.size();

	// The biases are written with 3 significant digits. Restore the previous
	// precision afterwards so it does not leak into the serialized models below
	// (previously every later value in the file was also limited to 3 digits,
	// unless write_regression_model set its own precision).
	const std::streamsize prev_precision = out_stream.precision(3);
	for (size_t i=0; i<this->charge_mz_biases.size(); i++)
		out_stream << " " << charge_mz_biases[i];
	out_stream << endl;
	out_stream.precision(prev_precision);

	// write ME models
	for (size_t i=0; i<pmc_models.size(); i++)
	{
		if (pmc_models[i])
		{
			out_stream << i << endl;
			pmc_models[i]->write_regression_model(out_stream);
		}
	}
	
	out_stream.close();
}


/******************************************************************************
Reads PMC models written by write_pmc_models() from <resource_dir>/<file>.

Sets bin_increment, frag_pair_sum_offset, max_charge (number of charge slots
minus 1), charge_mz_biases and pmc_models[charge] for every model in the file,
then sets ind_initialized_pmc.
Returns false if the file cannot be opened. Exits on a malformed file.
*******************************************************************************/
bool PMCSQS_Scorer::read_pmc_models(Config *_config, char *file)
{
	config = _config;

	string path;
	path = config->get_resource_dir() + "/" + string(file);


	ifstream in_stream(path.c_str(),ios::in);
	if (! in_stream.good())
	{
		cout << "Warning: couldn't open pmc model for reading: " << path << endl;
		return false;
	}


	char buff[256];
	int num_charges=-1;

	// line 1: bin increment and fragment-pair sum offset
	in_stream.getline(buff,sizeof(buff));
	istringstream iss1(buff);

	frag_pair_sum_offset=NEG_INF;
	bin_increment=NEG_INF;
	// Check the stream state as well as the sentinels: since C++11 a failed
	// numeric extraction stores 0, which the NEG_INF check alone misses.
	if (! (iss1 >> bin_increment >> this->frag_pair_sum_offset) ||
		frag_pair_sum_offset==NEG_INF || bin_increment == NEG_INF)
	{
		cout << "Error in pmc model file!" << endl;
		exit(1);
	}

	// line 2: number of charge slots followed by one m/z bias per slot
	in_stream.getline(buff,sizeof(buff));
	istringstream iss(buff);

	iss >> num_charges;
	if (num_charges <= 0)
	{
		cout << "Error in pmc model file: bad number of charges!" << endl;
		exit(1);
	}
	max_charge=num_charges-1;
	
	pmc_models.resize(num_charges,NULL);
	this->charge_mz_biases.resize(num_charges,0);

	int i;
	for (i=0; i<num_charges; i++)
		iss >> charge_mz_biases[i];
	
	// read ME models, each preceded by a line holding its charge
	while (in_stream.getline(buff,sizeof(buff)))
	{
		int charge=-1;

		// sscanf returns EOF for empty / whitespace-only lines: tolerate them
		if (sscanf(buff,"%d",&charge) == EOF)
			continue;

		if (charge<0 || charge>=num_charges)
		{
			cout << "Error: bad line in model!, got charge " << charge << endl;
			exit(1);
		}

		pmc_models[charge]=new ME_Regression_Model;
		pmc_models[charge]->read_regression_model(in_stream);
	}
	in_stream.close();

	this->ind_initialized_pmc = true;
	return true;
}


/******************************************************************************
Train PMC models from positive example files.

max_charge is set to the highest charge that has at least
MIN_SPECTRA_FOR_PMCSQS_MODEL spectra in pos_list. For every charge that has
enough spectra, up to 5000 randomly chosen spectra are used in two passes:
  1. charge_mz_biases[charge] is estimated as the average m/z difference between
     the bin that get_optimal_bin() picks and the bin closest to the peptide's
     true m/z;
  2. a two-class data set is built, with the (bias-corrected) correct bin of
     each spectrum labeled 0 and the bins chosen by
     select_training_sample_idxs() labeled 1. pmc_models[charge] is trained
     on it.
The models are written to <resource_dir>/<model_name>_PMC.txt.
*******************************************************************************/
void PMCSQS_Scorer::train_pmc_models(Config *config, char *pos_list)
{	
	// select_training_sample_idxs() reads the member config, so make sure it
	// refers to the configuration used for training
	this->config = config;

	FileManager fm;

	fm.init_from_list_file(config,pos_list);

	const vector<int>& spectra_counts = fm.get_spectra_counts();

	max_charge=0;

	int charge;
	for (charge=1; charge<(int)spectra_counts.size(); charge++)
	{
		if (spectra_counts[charge]>=MIN_SPECTRA_FOR_PMCSQS_MODEL)
			max_charge=charge;
	}

	this->set_frag_pair_sum_offset(1.00785); // b+y - PM+19
	this->set_bin_increment(0.1);
	const int max_to_read_per_file = 5000;

	pmc_models.resize(max_charge+1,NULL);
	charge_mz_biases.resize(max_charge+1,0);

	for (charge=1; charge<=max_charge; charge++)
	{
		if (spectra_counts[charge]<MIN_SPECTRA_FOR_PMCSQS_MODEL)
			continue;

		BasicSpecReader bsr;
		static QCPeak peaks[5000];
		ME_Regression_DataSet pmc_ds;

		pmc_ds.samples.clear();

		FileSet fs;		
		fs.select_files(fm,0,100000,-1,-1,charge);

		const vector<SingleSpectrumFile *>& all_ssf = fs.get_ssf_pointers();
		const int num_available = (int)all_ssf.size();
		const int num_samples = (num_available<max_to_read_per_file ? num_available :
									max_to_read_per_file);

		// avoids a division by zero when computing the bias below
		if (num_samples == 0)
		{
			cout << "Warning: no spectra selected for charge " << charge << ", skipping PMC model." << endl;
			continue;
		}

		int i;
		vector<int> ssf_idxs;
		if (num_samples<num_available)
		{
			choose_k_from_n(num_samples,num_available,ssf_idxs);
		}
		else
		{
			ssf_idxs.resize(num_available);
			for (i=0; i<num_available; i++)
				ssf_idxs[i]=i;
		}

		// Pass 1: find the bias between the m/z of the bin closest to the true
		// m/z and the m/z of the optimal bin
		mass_t total_bias=0;
		int num_bias_samples=0;
		for (i=0; i<num_samples; i++)
		{
			SingleSpectrumFile* ssf = all_ssf[ssf_idxs[i]];
			BasicSpectrum bs;
		
			bs.num_peaks = bsr.read_basic_spec(config,fm,ssf,peaks);
			bs.peaks = peaks;
			bs.ssf = ssf;

			ssf->peptide.calc_mass(config);
			
			const mass_t true_mz = (ssf->peptide.get_mass()+18.0105+(mass_t)charge)/(mass_t)charge;

			init_for_current_spec(config,bs);
			calculate_curr_spec_pmc_values(bs, bin_increment);

			const vector<PmcStats>& pmc_stats = curr_spec_pmc_tables[charge];
			const int num_bins = (int)pmc_stats.size();
			if (num_bins == 0)
				continue; // nothing to measure for this spectrum

			// closest bin to true_mz, clamped so neither pmc_stats[-1]
			// nor pmc_stats[num_bins] is ever read
			int true_mz_bin_idx=0;
			while (true_mz_bin_idx<num_bins && pmc_stats[true_mz_bin_idx].m_over_z<true_mz)
				true_mz_bin_idx++;

			if (true_mz_bin_idx>=num_bins)
			{
				true_mz_bin_idx = num_bins-1;
			}
			else if (true_mz_bin_idx>0 &&
					 pmc_stats[true_mz_bin_idx].m_over_z-true_mz>true_mz-pmc_stats[true_mz_bin_idx-1].m_over_z)
			{
				true_mz_bin_idx--;
			}

			int opt_bin_idx = get_optimal_bin(true_mz_bin_idx, charge);
			if (opt_bin_idx<0) // no optimal bin found: count as zero bias
				opt_bin_idx = true_mz_bin_idx;

			total_bias += (pmc_stats[opt_bin_idx].m_over_z - pmc_stats[true_mz_bin_idx].m_over_z);
			num_bias_samples++;
		}

		const mass_t mz_bias = (num_bias_samples>0 ? total_bias / (mass_t)num_bias_samples : 0);
		// must be set before pass 2: select_training_sample_idxs() uses it
		this->charge_mz_biases[charge]=mz_bias;

		cout << "m/z bias: " << setprecision(3) << mz_bias << endl;

		// Pass 2: one positive (correct bin) and several negative (offset bins)
		// samples per spectrum
		for (i=0; i<num_samples; i++)
		{
			SingleSpectrumFile* ssf = all_ssf[ssf_idxs[i]];
			BasicSpectrum bs;
		
			bs.num_peaks = bsr.read_basic_spec(config,fm,ssf,peaks);
			bs.peaks = peaks;
			bs.ssf = ssf;

			init_for_current_spec(config,bs);
			calculate_curr_spec_pmc_values(bs, bin_increment);

			vector<ME_Regression_Sample> spec_samples;

			fill_fval_vectors_with_PMC(bs, charge, spec_samples);

			// select samples and add them to pmc_ds
			int good_idx=-1;
			vector<int> bad_idxs;
			select_training_sample_idxs(charge,spec_samples,bs,good_idx,bad_idxs);

			// without a valid positive sample this spectrum contributes nothing
			if (good_idx<0 || good_idx>=(int)spec_samples.size())
				continue;

			spec_samples[good_idx].label = 0;
			pmc_ds.add_sample(spec_samples[good_idx]);

			int j;
			for (j=0; j<(int)bad_idxs.size(); j++)
			{
				const int bad_idx = bad_idxs[j];
				if (bad_idx < 0 || bad_idx>= (int)spec_samples.size())
					continue;

				spec_samples[bad_idx].label=1;
				pmc_ds.add_sample(spec_samples[bad_idx]);
			}
		}

		pmc_ds.num_classes=2;
		pmc_ds.num_features=PMC_NUM_FIELDS;
	
		pmc_ds.tally_samples();
		pmc_ds.calibrate_class_weights(0.5);

		cout << endl << "CHARGE " << charge << endl;
		pmc_ds.print_feature_summary(cout,PMC_var_names);

		pmc_models[charge]=new ME_Regression_Model;

		pmc_models[charge]->train_cg(pmc_ds,500);

		pmc_models[charge]->print_ds_probs(pmc_ds);
	}

	string path;
	path = config->get_resource_dir() + "/" + config->get_model_name() + "_PMC.txt";
	this->write_pmc_models(path.c_str());
	ind_initialized_pmc = true;
}



/****************************************************************************
Finds the bin which has the optimal values (look for the maximal number of pairs).
Performs search near the peptide's true m/z value to comenstate for systematic bias
in the precursor mass.
*****************************************************************************/
int PMCSQS_Scorer::get_optimal_bin(int true_mz_bin, int charge) const
{
	const int max_bin_offset = 9+2*charge; // look in the range +- of this value
	const vector<PmcStats>& pmc_stats = curr_spec_pmc_tables[charge];
	const int min_bin_idx = (true_mz_bin - max_bin_offset>=0 ? true_mz_bin - max_bin_offset : 0);
	const int max_bin_idx = (true_mz_bin + max_bin_offset>= pmc_stats.size() ? pmc_stats.size()-1 :
								true_mz_bin + max_bin_offset);

	if (pmc_stats[true_mz_bin].num_frag_pairs==0 &&
		pmc_stats[true_mz_bin].num_c2_frag_pairs==0)
		return true_mz_bin;
	
	int   optimal_bin_idx=-1;
	float max_num_pairs=0;
	float best_offset=99999999;

	if (pmc_stats[true_mz_bin].num_frag_pairs>=pmc_stats[true_mz_bin].num_c2_frag_pairs)
	{
		float max_num_pairs=0;
		int bin_idx;
		for (bin_idx = min_bin_idx; bin_idx<=max_bin_idx; bin_idx++)
			if (pmc_stats[bin_idx].num_frag_pairs > max_num_pairs)
				max_num_pairs = pmc_stats[bin_idx].num_frag_pairs;

		// find minimal offset
		for (bin_idx = min_bin_idx; bin_idx<=max_bin_idx; bin_idx++)
			if (pmc_stats[bin_idx].num_frag_pairs == max_num_pairs &&
				pmc_stats[bin_idx].tol_frag_pairs < best_offset)
			{
				optimal_bin_idx = bin_idx;
				best_offset = pmc_stats[bin_idx].tol_frag_pairs;
			}

		return optimal_bin_idx;
		
	}
	else
	// use the charge 2 fragment pairs
	{
		float max_num_pairs=0; 
		int bin_idx;
		for (bin_idx = min_bin_idx; bin_idx<=max_bin_idx; bin_idx++)
			if (pmc_stats[bin_idx].num_c2_frag_pairs > max_num_pairs)
				max_num_pairs = pmc_stats[bin_idx].num_c2_frag_pairs;

		// find minimal offset
		for (bin_idx = min_bin_idx; bin_idx<=max_bin_idx; bin_idx++)
			if (pmc_stats[bin_idx].num_c2_frag_pairs == max_num_pairs &&
				pmc_stats[bin_idx].tol_c2_frag_pairs < best_offset)
			{
				optimal_bin_idx = bin_idx;
				best_offset = pmc_stats[bin_idx].tol_c2_frag_pairs;
			}

		return optimal_bin_idx;	
	}


	return -1;
}


/*********************************************************************************
Chooses the PMC training samples for one spectrum of the given charge.

correct_idx receives the index of the pmc-table bin closest to the peptide's
true m/z, shifted by charge_mz_biases[charge] (-1 if the table is empty).
bad_pmc_idxs receives the bins at offsets +5, +10, +20, +30, -5 and -10 from it,
plus up to 4 random bins that are at least 6 bins away. The bad indices are NOT
range-checked here, so callers must skip values outside their sample vector.
spec_samples is not used by this function.
Exits the program if the true m/z is more than 6 away from the file's m/z.
**********************************************************************************/
void PMCSQS_Scorer::select_training_sample_idxs(
		int charge,
		const vector<ME_Regression_Sample>& spec_samples,
		const BasicSpectrum& bs,
		int& correct_idx,
		vector<int>& bad_pmc_idxs) const
{
	const vector<PmcStats>& pmc_stats = curr_spec_pmc_tables[charge];
	const int num_bins = (int)pmc_stats.size();

	bs.ssf->peptide.calc_mass(config);
	const mass_t pep_mass = bs.ssf->peptide.get_mass()+18.0105;
	const mass_t true_mz = (pep_mass + charge)/charge + this->charge_mz_biases[charge];
	const mass_t observed_mz = bs.ssf->m_over_z;

	// check that the training sample has an ok offset
	if (fabs(true_mz-observed_mz)>6.0)
	{
		
		cout << "Erorr in m/z offsets (remove this spectrum from training set): " << endl;
		cout << fixed << setprecision(2) << "file m/z: " << observed_mz << "\t" << 
			"\"true\" m/z: " << true_mz << "\t peptide: " << bs.ssf->peptide.as_string(config) << endl;
		cout << "spectrum: " << bs.ssf->single_name << endl;
		
		cout << "Mass Cys = " << this->config->get_aa2mass()[Cys] << endl;

		exit(1);
	}

	// Find the entry closest to the correct m/z. The scan stops at the first bin
	// whose m/z is >= true_mz, which is then compared with its left neighbour.
	// Clamp at both ends so pmc_stats[-1] / pmc_stats[num_bins] are never read.
	int idx=0;
	while (idx<num_bins && pmc_stats[idx].m_over_z<true_mz)
		idx++;

	if (idx>=num_bins)
	{
		idx = num_bins-1; // true m/z above the whole table (-1 for an empty table)
	}
	else if (idx>0 && pmc_stats[idx].m_over_z-true_mz>true_mz-pmc_stats[idx-1].m_over_z)
	{
		idx--;
	}

	correct_idx = idx;

	bad_pmc_idxs.clear();
	bad_pmc_idxs.push_back(correct_idx+5);
	bad_pmc_idxs.push_back(correct_idx+10);
	bad_pmc_idxs.push_back(correct_idx+20);
	bad_pmc_idxs.push_back(correct_idx+30);
	bad_pmc_idxs.push_back(correct_idx-5);
	bad_pmc_idxs.push_back(correct_idx-10);

	// select up to 4 random samples (make sure they are not close to the correct one)
	int i;
	for (i=0; i<4; i++)
	{
		const int rand_idx = (int)(my_random()*pmc_stats.size());
		if (abs(correct_idx-rand_idx)<6)
			continue;

		bad_pmc_idxs.push_back(rand_idx);
	}
}


/*************************************************************************
Tests the performance of precursor mass correction.

Reads up to 1000 randomly chosen spectra from specs_file and runs
find_best_mz_values() for the given charge on each. Prints the mean and
standard deviation of (true m/z - file m/z) ("ORG") and of
(true m/z - res.mz1 + charge_mz_biases[charge]) ("CORR").
**************************************************************************/
void PMCSQS_Scorer::test_pmc(Config *config, char *specs_file, int charge)
{
	// charge_mz_biases is filled by train_pmc_models()/read_pmc_models();
	// refuse charges it has no entry for instead of reading past its end
	if (charge < 0 || charge >= (int)this->charge_mz_biases.size())
	{
		cout << "Error: no PMC m/z bias available for charge " << charge << endl;
		return;
	}

	BasicSpecReader bsr;
	static QCPeak peaks[5000];

	FileManager fm;
	FileSet fs;
		
	fm.init_from_file(config,specs_file);
	fs.select_all_files(fm);

	const int max_to_read_per_file = 1000;

	const vector<SingleSpectrumFile *>& all_ssf = fs.get_ssf_pointers();
	const int num_available = (int)all_ssf.size();
	const int num_samples = (num_available<max_to_read_per_file ? num_available :
									max_to_read_per_file);

	if (num_samples == 0)
	{
		cout << "Error: no spectra to test in " << specs_file << endl;
		return;
	}
	
	vector<mass_t> org_offsets;
	vector<mass_t> corr_offsets;

	int i;
	vector<int> ssf_idxs;
	if (num_samples<num_available)
	{
		choose_k_from_n(num_samples,num_available,ssf_idxs);
	}
	else
	{
		ssf_idxs.resize(num_available);
		for (i=0; i<num_available; i++)
			ssf_idxs[i]=i;
	}
		
	for (i=0; i<num_samples; i++)
	{
		SingleSpectrumFile* ssf = all_ssf[ssf_idxs[i]];
		BasicSpectrum bs;
	
		bs.num_peaks = bsr.read_basic_spec(config,fm,ssf,peaks);
		bs.peaks = peaks;
		bs.ssf = ssf;

		init_for_current_spec(config,bs);
		calculate_curr_spec_pmc_values(bs, bin_increment);

		PmcSqsChargeRes res;
		find_best_mz_values(bs, charge, res);

		ssf->peptide.calc_mass(config);
		// same water mass as used in training (train_pmc_models,
		// select_training_sample_idxs); previously 18.01 was used here
		const mass_t true_mz = (ssf->peptide.get_mass() + 18.0105 + (mass_t)charge)/(mass_t)charge;

		org_offsets.push_back(true_mz - ssf->m_over_z);
		corr_offsets.push_back(true_mz - res.mz1 + this->charge_mz_biases[charge]);
	}

	mass_t m_org,sd_org,m_corr,sd_corr;
	calc_mean_sd(org_offsets,&m_org, &sd_org);
	calc_mean_sd(corr_offsets,&m_corr,&sd_corr);

	cout << "ORG:  mean " << m_org << " " << sd_org << endl;

	cout << "CORR: mean " << m_corr << " " << sd_corr << endl;
}


/***********************************************************************************

Functions for training set.


************************************************************************************/

// Key identifying one spectrum in the training index files: a file index and a
// scan number, optionally with the peptide sequence read for that scan.
// Ordering and equality look only at (file_idx, scan); seq is ignored. This lets
// sorted vectors of ScanPair be merged against the spectra in a single pass.
struct ScanPair {
	// seq is taken by const reference so temporaries (e.g. string literals) bind
	ScanPair(int f, int sc, const std::string& se) : file_idx(f), scan(sc), seq(se) {}
	ScanPair(int f, int s) : file_idx(f), scan(s) {}
	ScanPair() : file_idx(-1), scan(-1) {}

	// lexicographic order on (file_idx, scan)
	bool operator< (const ScanPair& other) const
	{
		return (file_idx<other.file_idx || 
			    (file_idx == other.file_idx && scan<other.scan));
	}

	bool operator == (const ScanPair& other) const
	{
		return (file_idx == other.file_idx && scan == other.scan);
	}


	int file_idx;
	int scan;
	std::string seq;
};

/***********************************************************************************
Reads "file_idx scan [sequence]" lines from file into final_pairs.

Lines that do not start with two integers (headers, blank lines) are skipped,
as are pairs with a negative index. A sequence is kept only if it is longer
than 2 characters. If more than max_size pairs are read, max_size of them are
chosen with choose_k_from_n(). The result is sorted by (file_idx, scan).
Exits the program if the file cannot be opened.
************************************************************************************/
void read_idxs_from_file(char *file, vector<ScanPair>& final_pairs, int max_size)
{
	ifstream inp(file,ios::in);
	
	if (! inp.good())
	{	
		cout << "Error opening: " << file << endl;
		exit(1);
	}

	vector<ScanPair> pairs;

	char buff[256];
	while (inp.getline(buff,256))
	{
		istringstream iss(buff);
		int f=-1, s=-1;
		string seq;

		// Previously a malformed line left s uninitialized (or zeroed by a failed
		// extraction) and could still be pushed as a bogus pair.
		if (! (iss >> f >> s))
			continue;
		iss >> seq; // optional

		if (f>=0 && s>=0)
		{
			if (seq.length()>2)
			{
				pairs.push_back(ScanPair(f,s,seq));
			}
			else
				pairs.push_back(ScanPair(f,s));
		}
	}
	inp.close();

	// a negative max_size means "no limit" (as before, via the unsigned comparison)
	if (max_size >= 0 && pairs.size() > (size_t)max_size)
	{
		vector<int> idxs;
		choose_k_from_n(max_size,pairs.size(),idxs);
		final_pairs.resize(max_size);
		int i;
		for (i=0; i<max_size; i++)
			final_pairs[i]=pairs[idxs[i]];
	}
	else
	{
		final_pairs=pairs;
	}

	sort(final_pairs.begin(),final_pairs.end());
}


void create_training_files(Config *config)
{
	char mzxml_list[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\HEK293_mzxml_list.txt"};
	char idxs_neg_file[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\H40ul_neg_samples.txt"};
//	char idxs1_file[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\H40ul_pos_samples.1.txt"};
//	char idxs2_file[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\H40ul_pos_samples.2.txt"};
//	char idxs2_file[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\Len10_pos_samples.2.txt"};
	char idxs1_file[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\sqs_train_pos_samples.1.txt"};
	char idxs2_file[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\sqs_train_pos_samples.2.txt"};
	char idxs3_file[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\H40ul_pos_samples.3.txt"};

	char out_base[]={"C:\\Work\\msms5\\PepNovoHQ\\pmcsqs\\sqs_train"};
	string out_neg (out_base); 
	string out1=out_neg;
	string out2=out_neg;
	string out3=out_neg;

	out_neg += "_neg.mgf";
	out1 += "_1.mgf";
	out2 += "_2.mgf";
	out3 += "_3.mgf";

	ofstream stream_neg (out_neg.c_str(),ios::out);
	ofstream stream1(out1.c_str(),ios::out);
	ofstream stream2(out2.c_str(),ios::out);
	ofstream stream3(out3.c_str(),ios::out);

	vector<ScanPair> neg_pairs, pairs1,pairs2,pairs3;


	read_idxs_from_file(idxs_neg_file,neg_pairs,12000);
	read_idxs_from_file(idxs1_file,pairs1,12000);
	read_idxs_from_file(idxs2_file,pairs2,12000);
	read_idxs_from_file(idxs3_file,pairs3,8000);

	cout << "Read " << neg_pairs.size() << " neg idxs\n";
	cout << "Read " << pairs1.size() << " pos1 idxs\n";
	cout << "Read " << pairs2.size() << " pos2 idxs\n";
	cout << "Read " << pairs3.size() << " pos3 idxs\n";

	vector<bool> file_inds;
	file_inds.resize(10000,false);
	int i;

	for (i=0; i<neg_pairs.size(); i++)
		file_inds[neg_pairs[i].file_idx]=true;

	for (i=0; i<pairs1.size(); i++)
		file_inds[pairs1[i].file_idx]=true;

	for (i=0; i<pairs2.size(); i++)
		file_inds[pairs2[i].file_idx]=true;

	for (i=0; i<pairs3.size(); i++)
		file_inds[pairs3[i].file_idx]=true;

	
	FileManager fm;
	FileSet fs;

	fm.init_from_list_file(config,mzxml_list,file_inds);
	fs.select_all_files(fm);
	const vector<SingleSpectrumFile *>& all_ssf = fs.get_ssf_pointers();

	

	// read spectra
	BasicSpecReader bsr;
	QCPeak peaks[5000];

	int num_out_neg=0, num_out1=0, num_out2=0, num_out3=0;
	int neg_idx=0,c1_idx=0,c2_idx=0,c3_idx=0;


	for (i=0; i<all_ssf.size(); i++)
	{
		MZXML_single *ssf = (MZXML_single *)all_ssf[i];
		ScanPair ssf_pair(ssf->file_idx,ssf->scan_number);
		string seq="";

		int out_dest=-1;

		while (neg_idx<neg_pairs.size() && neg_pairs[neg_idx]<ssf_pair)
			neg_idx++;

		if (neg_idx<neg_pairs.size() && neg_pairs[neg_idx]==ssf_pair)
			out_dest=0;


		while (c1_idx<pairs1.size() && pairs1[c1_idx]<ssf_pair)
			c1_idx++;
		if (c1_idx<pairs1.size() && pairs1[c1_idx]==ssf_pair)
		{
			seq = pairs1[c1_idx].seq;
			out_dest=1;
		}


		while (c2_idx<pairs2.size() && pairs2[c2_idx]<ssf_pair)
			c2_idx++;
		if (c2_idx<pairs2.size() && pairs2[c2_idx]==ssf_pair)
		{
			seq = pairs2[c2_idx].seq;
			out_dest=2;
		}


		while (c3_idx<pairs3.size() && pairs3[c3_idx]<ssf_pair)
			c3_idx++;
		if (c3_idx<pairs3.size() && pairs3[c3_idx]==ssf_pair)
		{
			seq = pairs3[c3_idx].seq;
			out_dest=3;
		}

		if (out_dest<0)
			continue;

		BasicSpectrum bs;
		bs.num_peaks = bsr.read_basic_spec(config,fm,ssf,peaks);
		bs.peaks = peaks;
		bs.ssf = ssf;

	//	if (out_dest>0)
	//		bs.ssf->peptide.parse_from_string(config,seq);
	
		char name_buff[64];
		if (out_dest==0)
		{
			sprintf(name_buff,"train_neg_%d_%d_%d",num_out_neg,ssf->file_idx,ssf->scan_number);
			bs.ssf->single_name = string(name_buff);
			bs.output_to_mgf(stream_neg,config);
			num_out_neg++;
			continue;
		}

		if (out_dest==1)
		{
			sprintf(name_buff,"train_pos1_%d_%d_%d",num_out1,ssf->file_idx,ssf->scan_number);
			bs.ssf->single_name = string(name_buff);
			bs.output_to_mgf(stream1,config,seq.c_str());
			num_out1++;
			continue;
		}

		if (out_dest==2)
		{
			sprintf(name_buff,"train_pos2_%d_%d_%d",num_out2,ssf->file_idx,ssf->scan_number);
			bs.ssf->single_name = string(name_buff);
			bs.output_to_mgf(stream2,config,seq.c_str());
			num_out2++;
			continue;
		}

		if (out_dest==3)
		{
			sprintf(name_buff,"train_pos3_%d_%d_%d",num_out3,ssf->file_idx,ssf->scan_number);
			bs.ssf->single_name = string(name_buff);
			bs.output_to_mgf(stream3,config,seq.c_str());
			num_out3++;
			continue;
		}
	}

	cout << "Wrote: " << endl;
	cout << "Neg " << num_out_neg << " / " << neg_pairs.size() << endl;
	cout << "Pos1 " << num_out1 << " / " << pairs1.size() << endl;
	cout << "Pos2 " << num_out2 << " / " << pairs2.size() << endl;
	cout << "Pos3 " << num_out3 << " / " << pairs3.size() << endl;

	stream_neg.close();
	stream1.close();
	stream2.close();
	stream3.close();
	
}





void PMCSQS_Scorer::print_spec(const BasicSpectrum& bs) const
{
	cout << bs.ssf->single_name << endl;
	int i;
	for (i=0; i<bs.num_peaks; i++)
	{
		cout << setprecision(2) << fixed << bs.peaks[i].mass << "\t" << bs.peaks[i].intensity << "\t";
		if (curr_spec_iso_levels[i]>0)
			cout << " ISO ";
		if (curr_spec_strong_inds[i])
			cout << " STRONG ";
		cout << endl;
	}
}





