/*****************************************************
 Amino Acid Preference Toolkit in Java
 Pathogen Project
 Department of Computer Science and Engineering
 University of South Carolina
 Columbia, SC 29208
 Contact Email: rose@cse.sc.edu
*****************************************************/

import java.io.*;
import java.util.*;

/**
 * Splits the entries of several "type" files into cross-validation groups
 * and writes, for every group, the training/testing data files plus the
 * shell scripts that drive SVM formatting, training, testing and analysis.
 *
 * Command-line arguments:
 *   args[0] = file listing one type file per line
 *   args[1] = fraction of the data to leave out per group, in (0, 1];
 *             the number of groups is (int)(1 / fraction)
 *   args[2] = 0 => test chunk size proportional to each type's size,
 *             anything else => test chunk size derived from args[3]
 *   args[3] = size of the smallest data set (used when args[2] != 0)
 *
 * All output files are written relative to the current working directory.
 */
public class ValidationSelectTrainingSets
{
	private static final String USAGE =
		"usage: java ValidationSelectTrainingSets <typeListFile> <fractionToLeaveOut> <selectionType: 0|1> <smallestSetSize>";

	/**
	 * Program entry point; see the class comment for the argument layout.
	 *
	 * @throws IllegalArgumentException if fewer than four arguments are given,
	 *         a numeric argument cannot be parsed, or the fraction is not in (0, 1]
	 * @throws IOException if an input file cannot be read or an output file
	 *         or directory cannot be created
	 */
	public static void main(String[] args) throws Exception
	{
		if (args == null || args.length < 4)
		{
			throw new IllegalArgumentException(USAGE);
		}

		double percentageToLeaveOut = Double.parseDouble(args[1]);
		int selectionType = Integer.parseInt(args[2]);
		int smallestSet = Integer.parseInt(args[3]);

		// A fraction of 0 would yield Integer.MAX_VALUE groups and a fraction
		// above 1 would yield none; written negatively so NaN is rejected too.
		if (!(percentageToLeaveOut > 0.0 && percentageToLeaveOut <= 1.0))
		{
			throw new IllegalArgumentException("fraction to leave out must be in (0, 1] but was " + args[1] + "; " + USAGE);
		}

		Random random = new Random();

		// read in all the possible types
		Vector typeFilenameVector = readLines(args[0], true);

		// store all the entries of each type, keyed by the type file name
		Hashtable typeTable = new Hashtable();
		for (int index = 0; index < typeFilenameVector.size(); index++)
		{
			String typeFile = (String)typeFilenameVector.elementAt(index);
			Vector typeVector = readLines(typeFile, false);
			typeTable.put(typeFile, typeVector);
			System.out.println("Read " + typeVector.size() + " entries for type " + typeFile);
		}

		// for each type, randomly draw the test chunks and spread the remainder
		int numberOfChunks = (int)(1 / percentageToLeaveOut);
		System.out.println("Generating " + numberOfChunks + " data sets");
		Hashtable globalChunkTable = new Hashtable();

		Enumeration typeEnumeration = typeTable.keys();
		while (typeEnumeration.hasMoreElements())
		{
			String typeName = (String)typeEnumeration.nextElement();
			Vector allForTypeVector = (Vector)typeTable.get(typeName);

			// draw from a working copy so typeTable keeps the complete list
			Vector pool = new Vector(allForTypeVector);

			int chunkSize;
			if (selectionType == 0)
			{
				chunkSize = (int)(allForTypeVector.size() * percentageToLeaveOut);
			}
			else
			{
				chunkSize = (int)(smallestSet * percentageToLeaveOut);
			}

			Hashtable chunkTable = drawChunks(pool, numberOfChunks, chunkSize, random);

			// whatever is still in the pool did not fit into a chunk
			System.out.println("Working on leftovers for: " +typeName);
			System.out.println("There are " + pool.size() + " leftovers");
			distributeLeftovers(pool, chunkTable, numberOfChunks, selectionType == 0);

			globalChunkTable.put(typeName, chunkTable);
		}

		// now actually generate the testing/training data files
		PrintWriter trainingFormatScriptWriter = null;
		PrintWriter testingFormatScriptWriter = null;
		PrintWriter trainingScriptWriter = null;
		PrintWriter testingScriptWriter = null;
		PrintWriter analysisScriptWriter = null;
		PrintWriter analysisInputWriter = null;
		try
		{
			trainingFormatScriptWriter = new PrintWriter(new FileWriter("generateTrainingSVM.sh"));
			testingFormatScriptWriter = new PrintWriter(new FileWriter("generateTestingSVM.sh"));
			trainingScriptWriter = new PrintWriter(new FileWriter("trainSVM.sh"));
			testingScriptWriter = new PrintWriter(new FileWriter("testSVM.sh"));
			analysisScriptWriter = new PrintWriter(new FileWriter("analyzeSVM.sh"));
			analysisInputWriter = new PrintWriter(new FileWriter("analysisInput"));

			for (int index = 0; index < numberOfChunks; index++)
			{
				trainingFormatScriptWriter.println("echo Formatting Training Set " + index);
				trainingFormatScriptWriter.println("java ValidationFormatForSVM "+ args[0] +" trainingData.group" + index + " testingGroup"+index+"/svmTrainingGroup"+index);
				trainingFormatScriptWriter.println("gzip testingGroup"+index+"/svmTrainingGroup*");
				testingFormatScriptWriter.println("echo Formatting Testing Set " + index);
				testingFormatScriptWriter.println("java ValidationFormatFileForSVMTesting testingData.group" + index + " testingGroup"+index+"/svmTestingGroup"+index);
				testingFormatScriptWriter.println("gzip testingGroup"+index+"/svmTestingGroup"+index);

				trainingScriptWriter.println("echo Training With Group " + index);
				trainingScriptWriter.println("gunzip testingGroup"+index+"/svmTrainingGroup*.gz");
				testingScriptWriter.println("echo Testing With Group " + index);
				testingScriptWriter.println("gunzip testingGroup"+index+"/svmTrainingGroup*.linearModel.gz");
				testingScriptWriter.println("gunzip testingGroup"+index+"/svmTestingGroup*.gz");

				int totalTesting = writeGroupData(index, typeTable, globalChunkTable, trainingScriptWriter, testingScriptWriter);

				trainingScriptWriter.println("gzip testingGroup"+index+"/svmTrainingGroup*");
				testingScriptWriter.println("gzip testingGroup"+index+"/svmTrainingGroup*.linearModel");
				testingScriptWriter.println("gzip testingGroup"+index+"/svmTestingGroup"+index);

				analysisScriptWriter.println("java ValidationCombineSVMResults testingGroup"+index+"/combinationInput testingGroup"+index+"/combinedSVMOutput " + totalTesting);
				analysisScriptWriter.println("java AnalyzeClassificationResults svmTypeList testingData.group"+index+".trueTypes testingGroup"+index+"/combinedSVMOutput " + typeTable.size() + " > testingGroup"+index+"/AnalysisResults");
				analysisInputWriter.println("testingGroup"+index+"/AnalysisResults");
			}
		}
		finally
		{
			// close() flushes; closing here also covers the failure paths
			closeIfOpen(trainingFormatScriptWriter);
			closeIfOpen(testingFormatScriptWriter);
			closeIfOpen(trainingScriptWriter);
			closeIfOpen(testingScriptWriter);
			closeIfOpen(analysisScriptWriter);
			closeIfOpen(analysisInputWriter);
		}
	}

	/**
	 * Reads the non-blank lines of a text file.
	 *
	 * @param path file to read
	 * @param trim when true, surrounding whitespace is stripped from each
	 *             kept line (used for the list of type file names)
	 * @return the kept lines in file order
	 */
	private static Vector readLines(String path, boolean trim) throws IOException
	{
		Vector lines = new Vector();
		BufferedReader reader = new BufferedReader(new FileReader(path));
		try
		{
			String line;
			while ((line = reader.readLine()) != null)
			{
				String trimmed = line.trim();
				if (trimmed.length() == 0)
				{
					// a blank line is neither a file name nor an entry
					continue;
				}
				lines.add(trim ? trimmed : line);
			}
		}
		finally
		{
			reader.close();
		}
		return lines;
	}

	/**
	 * Randomly moves entries out of pool into numberOfChunks chunks of at
	 * most chunkSize entries each. Entries that were drawn are removed from
	 * pool, so on return pool holds exactly the entries that were not drawn.
	 *
	 * @return table mapping the chunk index (as a decimal string) to its Vector
	 */
	private static Hashtable drawChunks(Vector pool, int numberOfChunks, int chunkSize, Random random)
	{
		Hashtable chunkTable = new Hashtable();
		for (int chunk = 0; chunk < numberOfChunks; chunk++)
		{
			Vector chunkVector = new Vector();
			for (int j = 0; j < chunkSize && !pool.isEmpty(); j++)
			{
				int randomIndex = random.nextInt(pool.size());
				chunkVector.add(pool.remove(randomIndex));
			}
			chunkTable.put("" + chunk, chunkVector);
		}
		return chunkTable;
	}

	/**
	 * Appends the undrawn entries to the chunks in round-robin order.
	 *
	 * The chunk index wraps around, so having more leftovers than chunks no
	 * longer looks up a chunk that does not exist.
	 *
	 * NOTE(review): when wrapAround is false (uniform selection) at most one
	 * leftover is added per chunk and the rest stay training-only, so that
	 * test chunks of large types do not grow far beyond the uniform size.
	 * Confirm this is the intended policy for uniform selection.
	 *
	 * @param leftovers      entries that were not drawn into any chunk
	 * @param chunkTable     table produced by drawChunks for the same type
	 * @param numberOfChunks number of chunks in chunkTable
	 * @param wrapAround     true to distribute every leftover
	 */
	private static void distributeLeftovers(Vector leftovers, Hashtable chunkTable, int numberOfChunks, boolean wrapAround)
	{
		if (numberOfChunks <= 0)
		{
			return;
		}
		int limit = wrapAround ? leftovers.size() : Math.min(leftovers.size(), numberOfChunks);
		for (int z = 0; z < limit; z++)
		{
			Vector chunkVector = (Vector)chunkTable.get("" + (z % numberOfChunks));
			chunkVector.add(leftovers.elementAt(z));
		}
		if (limit < leftovers.size())
		{
			System.out.println((leftovers.size() - limit) + " leftovers are used for training only");
		}
	}

	/**
	 * Writes the data files of one group: trainingData.group<index>,
	 * testingData.group<index>, its .trueTypes companion and
	 * testingGroup<index>/combinationInput. Also appends the per-type
	 * svm_learn / svm_classify lines to the two script writers.
	 *
	 * An entry goes to the testing files when it is in the type's chunk for
	 * this group, otherwise to the training file.
	 *
	 * @return number of entries written to the testing file
	 */
	private static int writeGroupData(int index, Hashtable typeTable, Hashtable globalChunkTable,
		PrintWriter trainingScriptWriter, PrintWriter testingScriptWriter) throws IOException
	{
		// combinationInput lives in this directory, so it has to exist first
		File groupDirectory = new File("testingGroup"+index);
		if (!groupDirectory.isDirectory() && !groupDirectory.mkdirs())
		{
			throw new IOException("Could not create directory " + groupDirectory.getPath());
		}

		String testingDataName = "testingData.group"+index;
		int totalTesting = 0;

		PrintWriter trainingWriter = null;
		PrintWriter testingWriter = null;
		PrintWriter testingTypeWriter = null;
		PrintWriter combinationInputWriter = null;
		try
		{
			trainingWriter = new PrintWriter(new FileWriter("trainingData.group"+index));
			testingWriter = new PrintWriter(new FileWriter(testingDataName));
			testingTypeWriter = new PrintWriter(new FileWriter(testingDataName+".trueTypes"));
			combinationInputWriter = new PrintWriter(new FileWriter("testingGroup"+index+"/combinationInput"));

			Enumeration typeEnumeration = typeTable.keys();
			while (typeEnumeration.hasMoreElements())
			{
				String typeName = (String)typeEnumeration.nextElement();
				combinationInputWriter.println(typeName+":testingGroup"+index+"/svmTestingGroup"+index+"." + typeName+".output.linearModel");
				trainingScriptWriter.println("/research/pathogen/svmLight/svm_learn testingGroup"+index+"/svmTrainingGroup"+index +"." + typeName +  " testingGroup"+index+"/svmTrainingGroup"+index+"."+typeName+".linearModel");
				testingScriptWriter.println("/research/pathogen/svmLight/svm_classify testingGroup" + index+"/svmTestingGroup"+index+" testingGroup"+index+"/svmTrainingGroup"+index+"." +typeName+".linearModel testingGroup"+index+"/svmTestingGroup" + index+"." +typeName+".output.linearModel");

				Hashtable chunkTable = (Hashtable)globalChunkTable.get(typeName);
				Vector chunkVector = (Vector)chunkTable.get(""+index);
				Vector allForTypeVector = (Vector)typeTable.get(typeName);

				// hashed membership gives the same answers as Vector.contains
				// without rescanning the chunk for every entry
				Set inChunk = new HashSet(chunkVector);

				for (int j = 0; j < allForTypeVector.size(); j++)
				{
					String filename = (String)allForTypeVector.elementAt(j);
					if (inChunk.contains(filename))
					{
						testingWriter.println(filename);
						testingTypeWriter.println(filename+":"+typeName);
						totalTesting = totalTesting + 1;
					}
					else
					{
						trainingWriter.println(filename + ":" + typeName);
					}
				}
			}
		}
		finally
		{
			closeIfOpen(trainingWriter);
			closeIfOpen(testingWriter);
			closeIfOpen(testingTypeWriter);
			closeIfOpen(combinationInputWriter);
		}
		return totalTesting;
	}

	/** Closes (and thereby flushes) the writer if it was opened. */
	private static void closeIfOpen(PrintWriter writer)
	{
		if (writer != null)
		{
			writer.close();
		}
	}
}
