/*****************************************************
 Amino Acid Preference Toolkit in Java
 Pathogen Project
 Department of Computer Science and Engineering
 University of South Carolina
 Columbia, SC 29208
 Contact Email: rose@cse.sc.edu
*****************************************************/

import java.io.*;
import java.util.*;

/**
 * Computes amino-acid triple (3-mer) preference vectors for a protein sequence.
 *
 * <p>Every overlapping window of three consecutive residues is counted into a
 * 20 x 20 x 20 table over the residues {@code ACDEFGHIKLMNPQRSTVWY}, flattened
 * into a vector of length 8000 where the triple (i, j, k) is stored at index
 * {@code 400 * i + 20 * j + k}. Windows containing {@code Z} or any character
 * outside that alphabet are skipped rather than counted.
 */
public class AminoAcidTriplePreference
{

	/** Residues in index order; the trailing 'Z' maps to SKIP_INDEX and is never counted. */
	private static final String RESIDUE_ORDER = "ACDEFGHIKLMNPQRSTVWYZ";

	/** Number of residues that are actually counted (A..Y). */
	private static final int ALPHABET_SIZE = 20;

	/** Index assigned to 'Z'; windows containing it are excluded from the counts. */
	private static final int SKIP_INDEX = 20;

	/** Length of the flattened triple vector (20^3 = 8000). */
	private static final int VECTOR_LENGTH = ALPHABET_SIZE * ALPHABET_SIZE * ALPHABET_SIZE;

	/** Maps a one-letter residue code (as a String) to its index in RESIDUE_ORDER. */
	Hashtable<String, Integer> aminoToIndexMap;

	/**
	 * Command-line entry point: reads a FASTA file, computes its normalized
	 * triple preference vector and writes it to the output file.
	 *
	 * @param args args[0] = amino acid filename (FASTA format), args[1] = output filename
	 */
	public static void main(String[] args)
	{
		if (args.length < 2)
		{
			System.err.println("Usage: java AminoAcidTriplePreference <aminoAcidFastaFile> <outputFile>");
			System.exit(1);
		}

		String aminoFilename = args[0];
		String outputFilename = args[1];

		// SequenceReader is defined elsewhere in the project; parseFastaFile is
		// presumed to return the file's residues as a single String (verify).
		SequenceReader sequenceReader = new SequenceReader();
		char[] aminoAcidSequence = sequenceReader.parseFastaFile(aminoFilename).toCharArray();
		AminoAcidTriplePreference triplePreference = new AminoAcidTriplePreference();
		double[] triplePreferenceVector = triplePreference.calculateTriplePreferenceVector(aminoAcidSequence);
		writeToFile(triplePreferenceVector, outputFilename);
	}

	/**
	 * Writes a vector to a file as one line of space-separated values, the format
	 * needed for classification tools, HMM learning, etc. This is best effort:
	 * I/O failures are reported on standard error rather than thrown.
	 *
	 * @param triplePreferenceVector values to write; an empty vector produces an empty file
	 * @param outputFilename path of the file to create or overwrite
	 */
	public static void writeToFile(double[] triplePreferenceVector, String outputFilename)
	{
		// try-with-resources closes the file even if writing fails part-way.
		try (PrintWriter printWriter = new PrintWriter(new FileWriter(outputFilename)))
		{
			for (int i = 0; i < triplePreferenceVector.length; i++)
			{
				if (i == (triplePreferenceVector.length - 1))
				{
					printWriter.println(triplePreferenceVector[i]);
				}
				else
				{
					printWriter.print(triplePreferenceVector[i] + " ");
				}
			}
			// PrintWriter never throws on write failures; it only sets an error flag.
			if (printWriter.checkError())
			{
				System.err.println("Error writing triple preference vector to " + outputFilename);
			}
		}
		catch (IOException exception)
		{
			exception.printStackTrace();
		}
	}

	/** Builds the residue-to-index map: A=0, C=1, ..., Y=19, and Z=20 (skipped). */
	public AminoAcidTriplePreference()
	{
		aminoToIndexMap = new Hashtable<String, Integer>();
		for (int index = 0; index < RESIDUE_ORDER.length(); index++)
		{
			aminoToIndexMap.put(String.valueOf(RESIDUE_ORDER.charAt(index)), Integer.valueOf(index));
		}
	}

	/**
	 * Computes the normalized triple preference vector: the relative frequency of
	 * each ordered residue triple among all counted windows of the sequence.
	 * Also prints the sum of the normalized entries to standard output as a sanity check.
	 *
	 * @param aminoAcidSequence residues as one-letter codes; must not be null
	 * @return a vector of length 8000 indexed by {@code 400 * i + 20 * j + k} that sums
	 *         to 1, or all zeros when the sequence contains no countable triple
	 *         (e.g. it is shorter than three residues)
	 */
	double[] calculateTriplePreferenceVector(char[] aminoAcidSequence)
	{
		double[] returnVector = countTriples(aminoAcidSequence);

		double sum = 0;
		for (double count : returnVector)
		{
			sum += count;
		}

		// normalize; with nothing counted keep zeros instead of producing 0/0 = NaN
		if (sum > 0)
		{
			for (int i = 0; i < returnVector.length; i++)
			{
				returnVector[i] = returnVector[i] / sum;
			}
		}

		// verify sum = 1
		double vSum = 0;
		for (double value : returnVector)
		{
			vSum += value;
		}
		System.out.println("Verification Sum (should = 1): " + vSum);

		return returnVector;
	}

	/**
	 * Computes the raw (unnormalized) count of each ordered residue triple.
	 *
	 * @param aminoAcidSequence residues as one-letter codes; must not be null
	 * @return a vector of length 8000 indexed by {@code 400 * i + 20 * j + k}
	 *         holding the number of windows equal to each triple
	 */
	public double[] calculateTriplePreferenceVectorNoNormalization(char[] aminoAcidSequence)
	{
		return countTriples(aminoAcidSequence);
	}

	/**
	 * Counts every overlapping window of three residues whose residues are all
	 * countable (A..Y); windows with 'Z' or unmapped characters are skipped.
	 */
	private double[] countTriples(char[] aminoAcidSequence)
	{
		Objects.requireNonNull(aminoAcidSequence, "aminoAcidSequence");

		double[] counts = new double[VECTOR_LENGTH];
		for (int i = 0; i < aminoAcidSequence.length - 2; i++)
		{
			int first = residueIndex(aminoAcidSequence[i]);
			int second = residueIndex(aminoAcidSequence[i + 1]);
			int third = residueIndex(aminoAcidSequence[i + 2]);
			if (first < 0 || second < 0 || third < 0)
			{
				continue; // window contains 'Z' or an unrecognized character
			}
			counts[(ALPHABET_SIZE * ALPHABET_SIZE * first) + (ALPHABET_SIZE * second) + third] += 1;
		}
		return counts;
	}

	/** Returns the index (0..19) of a countable residue, or -1 for 'Z' or any unmapped character. */
	private int residueIndex(char residue)
	{
		Integer index = aminoToIndexMap.get(String.valueOf(residue));
		if (index == null || index.intValue() == SKIP_INDEX)
		{
			return -1;
		}
		return index.intValue();
	}
}
