#include "Model.h"
#include "FragmentSelection.h"


// reads a model and all relevant files
// the model files are assumed to be in the resource_dir
// all this model's files are assumed to have a name <model_name>_XXXXX.txt
// the main model file is <model_name>.txt
void Model::read_model(const char* name)
{
	char file[256];

	model_name = name;

	if (config.get_resource_dir().length()<2)
	{
		config.set_resource_dir("Models");
	}

	config.set_model_name(string(name));

	strcpy(file,config.get_resource_dir().c_str());
	strcat(file,"/");
	strcat(file,name); 
	strcat(file,".txt");   

	fstream fs(file,ios::in);
	if (! fs.good() )  
	{
		cout << "Error: couldn't open model file: " << file << endl;
		exit(1);
	}

	while (! fs.eof())
	{
		char buff[1024];
		fs.getline(buff,1024);
		if (fs.gcount()<4)
			continue;

		char arg[128];
		if (sscanf(buff,"#CONFIG_FILE %s",arg) == 1)
		{
			config.read_config(arg);
			config.set_model_name(string(model_name));
			continue;
		}

		if (! strncmp("#CONF",buff,5))
		{
			string path = config.get_resource_dir() + "/" + string(buff);
			config.parse_config_parameter((char *)path.c_str());
			continue;
		}

		if (sscanf(buff,"#BREAK_SCORE_MODEL %s",arg) ==1)
		{
			string path = config.get_resource_dir() + "/" + string(arg);
			ifstream is(path.c_str(),ios::in);
			read_score_model(is);
			continue;
		}

		if (sscanf(buff,"#SQS_MODEL %s",arg) == 1)
		{
			pmcsqs.read_sqs_models(&config,arg);
			continue;
		}

		if (sscanf(buff,"#PMC_MODEL %s",arg) == 1)
		{
			pmcsqs.read_pmc_models(&config,arg);
			continue;
		}
	}

	// check if some of the defaults need to be changed
	if (config.get_max_edge_length() != 2)
		config.calc_aa_combo_masses();

}



// writes a model and all relevant files
// the model files are assumed to be in the resource_dir
// all this model's files are assumed to have a name <model_name>_XXXXX.txt
// the main model file is <model_name>.txt
void Model::write_model()
{
	string model_file;

	model_file = config.get_resource_dir() + "/" + model_name + ".txt";

	fstream os(model_file.c_str(),ios::out);
	if ( ! os.good())
	{
		cout << "Error writing model to " << model_file << endl;
		exit(1);
	}


	string config_file = config.get_resource_dir() + "/" + model_name + "_config.txt";
	config.set_config_file(config_file);
	config.set_model_name(model_name);
	os << "#CONFIG_FILE " << model_name + "_config.txt" << endl;
	config.write_config();

	string score_model = config.get_resource_dir() + "/" + model_name + "_break_score.txt";
	os << "#BREAK_SCORE_MODEL " << model_name + "_break_score.txt" << endl;

	if (pmcsqs.get_ind_initialized_pmc())
	{
		os << "#PMC_MODEL " << model_name + "_PMC.txt" << endl;
		string path = config.get_resource_dir() + "/" + model_name + "_PMC.txt";
		pmcsqs.write_pmc_models(path.c_str());
	}

	if (pmcsqs.get_ind_initialized_sqs())
	{
		os << "#SQS_MODEL " << model_name + "_SQS.txt" << endl;
		string path = config.get_resource_dir() + "/" + model_name + "_SQS.txt";
		pmcsqs.write_sqs_models(path.c_str());
	}

	fstream sm(score_model.c_str(),ios::out);
	write_score_model(sm);
}


/*********************************************************************
// this function performs the entire training process of the model
**********************************************************************/
void Model::full_train_model(const char *name, const FileManager& fm, 
							 mass_t initial_tolerance)
{
	config.set_tolerances(initial_tolerance);

	model_name = name;
	config.set_model_name(string(name));

	int charge;
	for (charge = fm.get_min_charge(); charge<= fm.get_max_charge(); charge++)
	{
		vector<mass_t> spectra_masses;
		FileSet fs;
		fs.select_all_files(fm);
		const vector<SingleSpectrumFile *>& all_ssf = fs.get_ssf_pointers();
		int i;
		for (i=0; i<all_ssf.size(); i++)
			if (all_ssf[i]->charge == charge)
				spectra_masses.push_back(all_ssf[i]->org_pm_with_19);

		config.set_size_thresholds_according_to_set_of_masses(charge,spectra_masses);
	}


	cout << "Selecting fragments ..." << endl;
	select_fragments(name,fm,15,0.05);


	// find the tolerance and precursor mass tolerance

	cout << "Calculating precursor mass tolerance..." << endl;

	float pm_tol = calc_parent_mass_tolerance_distribution(this, fm, 0.95);

	cout << "Calculating fragment mass tolerance..." << endl;

	float tol    = calc_tolerance_distribution(this, fm , initial_tolerance*1.2,0.96);


	config.set_pm_tolerance(pm_tol);

	if (pm_tol <0.000001)
	{
		pm_tol = tol;
	}

	if (pm_tol<tol)
	{
		config.set_tolerance(tol+pm_tol);
	}
	else
		config.set_tolerance(tol);

	if (config.get_tolerance()<0.03)
	{
		config.set_terminal_score(12.0);
		config.set_max_edge_length(3);
	}
	else if (config.get_tolerance()<0.1)
	{
		config.set_terminal_score(8.0);
		config.set_max_edge_length(2);
	}
	else
	{
		config.set_terminal_score(5.0);
		config.set_max_edge_length(2);
	}

	write_model();

	init_score_model();

	train_score_model(name,fm,0);

	write_model();
}



/******************************************************************************
	This model selects the fragment types that will take part in the models.
	The fragments are selected according to the offset frequency function.
	If there isn't a suffcient number of spectra from the desired charge,
	the fragment selection is skipped.
*******************************************************************************/
bool Model::select_fragments(const char *name, const FileManager& fm, 
							 int max_num_frags, float min_prob)
{
	FragmentTypeSet fragment_types;

	int c;
	cout << "Training set consists of:" << endl;
	for (c=fm.get_min_charge(); c<=fm.get_max_charge(); c++)
		cout << "Charge " << c <<"  " << fm.get_num_spectra(c) << " spectra."<< endl;
	cout<<endl;

	// select potential fragment type using the fragment offset test
	select_frags_using_frag_offset_counts(fm,&config,fragment_types, min_prob);

	// add these fragments to the existing set
	config.add_fragment_types(fragment_types);

	
	for (c=1; c<=fm.get_max_charge(); c++)
	{
	
		// check that there is a minimal number of files...
		int num_charge_spectra = fm.get_num_spectra(c);

		if (num_charge_spectra<MINIMAL_NUMBER_SPECTRA_FOR_FRAGMENT_SELECTION)
			continue;

		config.init_regional_fragment_set_defaults(0,c);

		select_regional_fragments(fm,&config,c,true);

		config.select_fragments_in_sets(1.0,max_num_frags);

		// select strong, combos...
		int max_num_combos = max_num_frags > 0 ? max_num_frags : 2;
		if (max_num_combos>3)
			max_num_combos = 3;
		
		config.select_strong_fragments(c,0.5,3);
		select_frag_combos(fm,&config,c,max_num_combos);
	}

	string fragments_file = config.get_resource_dir() + "/" + string(name) + "_fragments.txt";
	ofstream os(fragments_file.c_str(),ios::out);
	config.print_fragments(os);
	config.set_fragments_file(fragments_file);
	os.close();

	string regional_fragment_sets_file = config.get_resource_dir() + "/" + string(name) + "_fragment_sets.txt";
	os.open(regional_fragment_sets_file.c_str(),ios::out);
	config.print_regional_fragment_sets(os);
	config.set_regional_fragment_sets_file(regional_fragment_sets_file);
	os.close();

	return true;
}




// Determines the fragment mass tolerance within which (at least) a
// *cutoff_prob* proportion of the abundant fragments are caught.
//
// The abundant fragments are the config's strong fragment types 1 and 2 (those
// with a non-negative index). For every training spectrum with at least three
// expected breakage masses, each of these types is looked up at every breakage
// except the first and the last. When s.get_max_inten_peak(exp_mass,
// max_tolerance) finds a peak (presumably the most intense within
// max_tolerance), its offset from the expected mass is recorded. A 41-bin
// histogram of the offsets (bin width max_tolerance/20, labelled with the
// signed offset) is printed.
//
// Returns the largest absolute offset left after discarding the
// floor((1-cutoff_prob)*n) offsets of largest magnitude (at least one offset
// is always kept). Returns max_tolerance, with a warning, when no strong
// fragment type is set or no fragment peak was found.
mass_t calc_tolerance_distribution(Model *model, const FileManager& fm, mass_t max_tolerance,
								   float cutoff_prob)
{
	FileSet fs;
	Config *config = model->get_config();

	fs.select_all_files(fm);

	vector<int> test_frag_idxs;
	if (config->get_strong_type1_idx()>=0)
		test_frag_idxs.push_back(config->get_strong_type1_idx());

	if (config->get_strong_type2_idx()>=0)
		test_frag_idxs.push_back(config->get_strong_type2_idx());

	if (test_frag_idxs.empty())
	{
		cout << endl <<"Warning: no strong fragments selected, using maximal tolerance!!!" << endl << endl;
		return max_tolerance;
	}

	const mass_t tol_increment = max_tolerance * 0.05;
	vector<float> offset_bins(41,0.0f);
	vector<mass_t> offsets;

	while(1)
	{
		Spectrum s;
		if (! fs.get_next_spectrum(fm,config,&s))
			break;

		s.init_spectrum();

		vector<mass_t> break_masses;
		s.get_peptide().calc_expected_breakage_masses(config,break_masses);
		if (break_masses.size()<3)
			continue;

		const mass_t true_mass_with_19 = s.get_peptide().get_mass() + 19.0183;

		// loop on fragments first, so high count fragments get precedence over
		// low count fragments that are actually due to b/y ions of previous or
		// next amino acids
		for (size_t f=0; f<test_frag_idxs.size(); f++)
		{
			const FragmentType& frag = config->get_fragment(test_frag_idxs[f]);

			// the first and last breakages are skipped
			for (size_t b=1; b+1<break_masses.size(); b++)
			{
				const mass_t exp_mass = frag.calc_expected_mass(break_masses[b],true_mass_with_19);
				const int p_idx = s.get_max_inten_peak(exp_mass,max_tolerance);
				if (p_idx<0)
					continue;

				const mass_t offset = s.get_peak_mass(p_idx) - exp_mass;

				// truncation toward zero makes bin 20 span (-inc, +inc); offsets
				// beyond +-max_tolerance land in the end bins
				int bin_idx = 20 + (int)((offset / max_tolerance)*20);
				if (bin_idx<0)
					bin_idx=0;
				if (bin_idx>40)
					bin_idx=40;

				offset_bins[bin_idx]++;
				offsets.push_back(offset);
			}
		}
	}

	// nothing to normalize the histogram by or to take a cutoff from
	if (offsets.empty())
	{
		cout << endl << "Warning: no strong fragment peaks found, using maximal tolerance!!!" << endl << endl;
		return max_tolerance;
	}

	const size_t total_frag_count = offsets.size();

	// bin i holds offsets around (i-20)*tol_increment, so label it with that sign
	cout << "bin histogram: " << endl;
	for (int i=0; i<=40; i++)
		cout << setprecision(4) << (i-20)*tol_increment << " " << 
			    offset_bins[i]/total_frag_count << endl;

	// Discard the floor((1-cutoff_prob)*n) offsets of largest magnitude; the
	// cutoff is the largest magnitude that remains. Clamping the discard count
	// keeps at least one offset, so small sets yield the largest observed
	// offset (not -1) and no index can fall outside the array.
	vector<mass_t> abs_offsets;
	abs_offsets.reserve(total_frag_count);
	for (size_t j=0; j<total_frag_count; j++)
		abs_offsets.push_back(fabs(offsets[j]));
	sort(abs_offsets.begin(),abs_offsets.end());

	const int num_offsets = (int)total_frag_count;
	int num_discarded = (int)((1.0 - cutoff_prob)*num_offsets);
	if (num_discarded<0)
		num_discarded=0;
	if (num_discarded>num_offsets-1)
		num_discarded=num_offsets-1;

	const mass_t cutoff_offset = abs_offsets[num_offsets-1-num_discarded];

	cout << "offset for " << cutoff_prob << " is " << cutoff_offset << endl;
	return cutoff_offset;
}


// Determines the parent (precursor) mass tolerance within which (at least) a
// *cutoff_prob* proportion of the training spectra fall.
//
// For every training spectrum with at least three expected breakage masses,
// the offset between its org_pm_with_19 and its peptide mass + 19.0183 is
// printed (with the charge and peptide) and collected.
//
// Returns the largest absolute offset left after discarding the
// floor((1-cutoff_prob)*n) offsets of largest magnitude (at least one offset
// is always kept). Returns -1 when no spectrum was usable; callers must treat
// that as "no estimate".
mass_t calc_parent_mass_tolerance_distribution(Model *model,  const FileManager& fm, 
											   float cutoff_prob)
{
	FileSet fs;
	Config *config = model->get_config();

	fs.select_all_files(fm);

	vector<mass_t> offsets;

	while(1)
	{
		Spectrum s;
		if (! fs.get_next_spectrum(fm,config,&s))
			break;

		s.init_spectrum();

		vector<mass_t> break_masses;
		s.get_peptide().calc_expected_breakage_masses(config,break_masses);
		if (break_masses.size()<3)
			continue;

		const mass_t true_mass_with_19 = s.get_peptide().get_mass() + 19.0183;
		const mass_t offset = s.get_org_pm_with_19() - true_mass_with_19;
		cout << setprecision(3) << offset << " " << s.get_charge() << " " << s.get_peptide().as_string(config) << endl;
		offsets.push_back(offset);
	}

	// Discard the floor((1-cutoff_prob)*n) offsets of largest magnitude; the
	// cutoff is the largest magnitude that remains. Clamping the discard count
	// keeps at least one offset, so small sets yield the largest observed
	// offset (not -1) and no index can fall outside the array.
	mass_t cutoff_offset=-1;
	if (! offsets.empty())
	{
		vector<mass_t> abs_offsets;
		abs_offsets.reserve(offsets.size());
		for (size_t j=0; j<offsets.size(); j++)
			abs_offsets.push_back(fabs(offsets[j]));
		sort(abs_offsets.begin(),abs_offsets.end());

		const int num_offsets = (int)abs_offsets.size();
		int num_discarded = (int)((1.0 - cutoff_prob)*num_offsets);
		if (num_discarded<0)
			num_discarded=0;
		if (num_discarded>num_offsets-1)
			num_discarded=num_offsets-1;

		cutoff_offset = abs_offsets[num_offsets-1-num_discarded];
	}

	cout << "Parent mass offset for " << setprecision(4) << cutoff_prob << " is " << cutoff_offset << endl;
	return cutoff_offset;
}







