variance 공식 test 본문

Programming/Java

variance 공식 test

halatha 2011. 4. 5. 06:37
import it.unimi.dsi.fastutil.doubles.Double2IntOpenHashMap;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
public class TestVariance {
	private final long[]	values;//	=	{ 1, 2, 2, 4, 9, 2, 200, 4, 4, 2 };
	private static final int FRACTION_NUM	=	10;
	//	create random data array
	public TestVariance(final int N)
	{
		final Random r = new Random(2233995);
		
		values	=	new long[N];
		
		for (int i = 0; i < N; ++i) {
			long v = 0;
			final int J = 8;
			// gaussian
			for (int j = 0; j < J; ++j) {
				v += r.nextInt(Math.max(N / J, 32 * 1024));
			}
			values[i] = v;
		}
	}
	
	//	receive data array
	public TestVariance(final long[] values)
	{
		this.values	=	new long[values.length];
		for ( int i = 0; i < values.length; ++i )
			this.values[i] = values[i];
	}
	
	//	formulae from http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
	//	all get*Variance*() methods except Kurtosis have two kinds of version,
	//	one uses double, another BigDecimal to compare results with each other
	public double getNaiveVarianceP()
	{
		long	sum	=	0;
		long	sumSqr	=	0;
		
		for ( int i = 0; i < values.length; ++i )
		{
			sum	+=	values[i];
			sumSqr	+=	values[i] * values[i];
		}
		
		double	mean	=	(double)sum / values.length;
		double	variance	=	(double)(sumSqr - sum * mean) / (values.length - 1);
		
		return	variance;
	}
	
	public BigDecimal getNaiveVarianceBD()
	{
		long	sum	=	0;
		long	sumSqr	=	0;
		
		for ( int i = 0; i < values.length; ++i )
		{
			sum	+=	values[i];
			sumSqr	+=	values[i] * values[i];
		}
		
		BigDecimal	mean	=	new BigDecimal(sum).divide(new BigDecimal(values.length), FRACTION_NUM, BigDecimal.ROUND_UP);
		BigDecimal	variance	=	new BigDecimal(sumSqr).subtract(new BigDecimal(sum).multiply(mean)).divide(new BigDecimal(values.length - 1), FRACTION_NUM, BigDecimal.ROUND_UP);	
		
		return	variance;
	}
	
	public double getTwoPassVarianceP()
	{
		long	sum1	=	0;
		long	sum2	=	0;
		for ( int i = 0; i < values.length; ++i )
		{
			sum1	+=	values[i];
		}
		
		double	mean	=	(double)sum1 / values.length;
		for ( int i = 0; i < values.length; ++i )
		{
			sum2	+=	Math.pow((values[i] - mean), 2);
		}
		double	variance	=	(double)sum2 / (values.length - 1);
		
		return variance;
	}
	
	public BigDecimal getTwoPassVarianceBD()
	{
		long	sum1	=	0;
		BigDecimal	sum2	=	new BigDecimal("0.0");
		for ( int i = 0; i < values.length; ++i )
		{
			sum1	+=	values[i];
		}
		
		BigDecimal	mean	=	new BigDecimal(sum1).divide(new BigDecimal(values.length), FRACTION_NUM, BigDecimal.ROUND_UP);
		for ( int i = 0; i < values.length; ++i )
		{
			BigDecimal	data	=	new BigDecimal(values[i]).subtract(mean);
			sum2	=	sum2.add(data.multiply(data));
		}
		BigDecimal	variance	=	sum2.divide(new BigDecimal(values.length - 1), FRACTION_NUM, BigDecimal.ROUND_UP);
		
		return variance;
	}
	
	public double getCompensatedVarianceP()
	{
		long	sum1	=	0;
		
		for ( int i = 0; i < values.length; ++i )
		{
			sum1	+=	values[i];
		}
		
		double	mean	=	(double)sum1 / values.length;
		
		long	sum2	=	0;
		long	sum3	=	0;
		
		for ( int i = 0; i < values.length; ++i )
		{
			sum2	+=	Math.pow((values[i] - mean), 2);
			sum3	+=	(values[i] - mean);
		}
		
		double	variance	=	((double)sum2 - (double)Math.pow(sum3, 2) / values.length) / (values.length - 1);
		
		return	variance;
	}
	
	public BigDecimal getCompensatedVarianceBD()
	{
		long	sum1	=	0;
		
		for ( int i = 0; i < values.length; ++i )
		{
			sum1	+=	values[i];
		}
		
		BigDecimal	mean	=	new BigDecimal(sum1).divide(new BigDecimal(values.length), FRACTION_NUM, BigDecimal.ROUND_UP);
		
		BigDecimal	sum2	=	new BigDecimal("0.0");
		BigDecimal	sum3	=	new BigDecimal("0.0");
		
		for ( int i = 0; i < values.length; ++i )
		{
			BigDecimal	data	=	new BigDecimal(values[i]).subtract(mean);
			sum2	=	sum2.add(data.multiply(data));
			sum3	=	sum3.add(data);
		}
		BigDecimal	variance	=	sum2.subtract(sum3.multiply(sum3).divide(new BigDecimal(values.length), FRACTION_NUM, BigDecimal.ROUND_UP)).divide(new BigDecimal(values.length - 1), FRACTION_NUM, BigDecimal.ROUND_UP);
		
		return	variance;
	}
	
	public double getOnlineVarianceP()
	{
		double	mean	=	0.0D;
		double	M2	=	0.0D;
		
		for ( int i = 0; i < values.length; ++i )
		{
			int	n	=	i + 1;
			double	delta	=	values[i] - mean;
			mean	+=	delta / n;
			M2	+=	delta * (values[i] - mean);
		}
		
//		double	variance_n	=	M2 / values.length;
		double	variance	=	M2 / (values.length - 1);
		
		return	variance;
	}
	
	public BigDecimal getOnlineVarianceBD()
	{
		BigDecimal	mean	=	new BigDecimal("0.0");
		BigDecimal	M2	=	new BigDecimal("0.0");
		
		for ( int i = 0; i < values.length; ++i )
		{
			BigDecimal	delta	=	new BigDecimal(values[i]).subtract(mean);
			mean	=	mean.add(delta.divide(new BigDecimal(i + 1), FRACTION_NUM, BigDecimal.ROUND_UP));
			M2	=	M2.add(delta.multiply(new BigDecimal(values[i]).subtract(mean)));
		}
		
//		double	variance_n	=	M2 / values.length;
		BigDecimal	variance	=	M2.divide(new BigDecimal(values.length - 1), FRACTION_NUM, BigDecimal.ROUND_UP);
		
		return	variance;
	}
	
	public double getWeightedIncrementalVarianceP()
	{
		double	mean	=	0.0D;
		double	S	=	0.0D;
		double	sumWeight	=	0.0D;
		
		for ( int i = 0; i < values.length; ++i )
		{
			double	temp	=	1.0D + sumWeight;
			double	Q	=	values[i] - mean;
			double	R	=	Q * 1.0D / temp;
			S	+=	sumWeight * Q * R;
			mean	+=	R;
			sumWeight	=	temp;
		}
		
		double	variance	=	S / (sumWeight - 1);
		
		return	variance;
	}
	
	public BigDecimal getWeightedIncrementalVarianceBD()
	{
		BigDecimal	mean	=	new BigDecimal("0.0");
		BigDecimal	S	=	new BigDecimal("0.0");
		BigDecimal	sumWeight	=	new BigDecimal("0.0");
		BigDecimal	one	=	new BigDecimal("1.0");
		
		for ( int i = 0; i < values.length; ++i )
		{
			BigDecimal	temp	=	sumWeight.add(one);
			BigDecimal	Q	=	new BigDecimal(values[i]).subtract(mean);
			BigDecimal	R	=	Q.divide(temp, FRACTION_NUM, BigDecimal.ROUND_UP);
			S	=	S.add(sumWeight.multiply(Q).multiply(R));
			mean	=	mean.add(R);
			sumWeight	=	temp;
		}
		
		BigDecimal	variance	=	S.divide(sumWeight.subtract(one), FRACTION_NUM, BigDecimal.ROUND_UP);
		
		return	variance;
	}
	
	public double getOnlineKurtosisVarianceP()
	{
		int	n	=	0;
		double	mean	=	0.0D;
		double	M2	=	0.0D;
		double	M3	=	0.0D;
		double	M4	=	0.0D;
		
		for ( int i = 0; i < values.length; ++i )
		{
			int	n1	=	n;
			++n;
			double	delta	=	values[i] - mean;
			double	deltaN	=	delta / n;
			double	deltaN2	=	deltaN * deltaN;
			double	term1	=	delta * deltaN * n1;
			mean	+=	deltaN;
			M4	+=	term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * M2 - 4 * deltaN * M3;
			M3	+=	term1 * deltaN * (n - 2) - 3 * deltaN * M2;
			M2	+=	term1;
		}
		double	kurtosis	=	(n * M4) / (M2 * M2) - 3;
		return	kurtosis;
	}
	//	execute test and compare results to see which is the best
	private void executeTest(Map>String, MethodLog<	methodLogMap) throws IllegalArgumentException, IllegalAccessException, InvocationTargetException
	{
		long	time0, time1;
		Method[]	methodArray	=	new TestVariance(values).getClass().getDeclaredMethods();
		List>ResultLog<	resultLogList	=	new ArrayList>ResultLog<(methodArray.length);
		
		Double2IntOpenHashMap	stdCandidateMap	=	new Double2IntOpenHashMap();
		resultLogList.removeAll(resultLogList);
		DecimalFormat	df	=	new DecimalFormat("0.##########E0");
		df.setMaximumFractionDigits(FRACTION_NUM);
		//	execute calculation using method name
		//	then, save it to the map
		for ( Method m : methodArray )
		{
			final String	name	=	m.getName();
			if ( true == name.startsWith("get") )
			{
				if ( null == methodLogMap.get(name) )
					methodLogMap.put(name, new MethodLog(name));
				
				if ( true == name.endsWith("P") )
				{
					time0	=	System.nanoTime();
					Double	variance	=	(Double) m.invoke(this, null);
					time1	=	System.nanoTime();
//					System.out.printf("%-35s\t%f\t%d ns\n", name, variance, (time1 - time0));
					if ( 0 < variance.doubleValue() )
					{
						stdCandidateMap.add(Double.valueOf(df.format(variance.doubleValue())), 1);
					}
					resultLogList.add(new ResultLog(name, variance, time1 - time0));
				}
				else if ( true == name.endsWith("BD") )
				{
					time0	=	System.nanoTime();
					BigDecimal	variance	=	(BigDecimal) m.invoke(this, null);
					time1	=	System.nanoTime();
//					System.out.printf("%-35s\t%f\t%d ns\n", name, variance, (time1 - time0));
					if ( 0 < variance.doubleValue() )
					{
						stdCandidateMap.add(Double.valueOf(df.format(variance.doubleValue())), 1);
					}
					resultLogList.add(new ResultLog(name, variance, time1 - time0));
				}
			}
		}
		
//		String	std	=	null;
		BigDecimal	std	=	new BigDecimal("0.0");
		int	num	=	0;
		//	decide most occurring value as standard among calculated variances
		for ( Entry>Double, Integer< item : stdCandidateMap.entrySet() )
		{
//			System.out.println(item.getKey() + " -> " + item.getValue());
			if ( num < item.getValue() )
			{
				num	=	item.getValue();
//				std	=	df.format(item.getKey());
				std	=	new BigDecimal(df.format(item.getKey()));
//				System.out.println("\t" + std + " selected because " + num + " exists");
			}
		}
		System.out.println("\n----------- " + values.length + " values\tstandard: " + std);
		final class ResultLogComparator implements Comparator>ResultLog<
		{
			@Override
			public int compare(ResultLog r1, ResultLog r2)
			{
				long	result	=	r1.getTime() - r2.getTime();
				return	0 < result ? 1 : (result < 0 ? -1 : 0);
			}
		}
		Collections.sort(resultLogList, new ResultLogComparator());
		int	rank	=	0;
		BigDecimal	difference	=	new BigDecimal("0.000001");
		for ( ResultLog rl : resultLogList )
		{
//			String	strStd	=	std;
//			String	strRes	=	df.format(rl.getVariance());
			BigDecimal	res	=	new BigDecimal(df.format(rl.getVariance()));
			boolean	cmpRes	=	true;
//			for ( int j = 0; j < Math.min(strStd.length(), strRes.length()); ++j )
//				if ( strStd.charAt(j) != strRes.charAt(j) )
//					cmpRes	=	false;
//			System.out.println(strStd + "\tvs.\t" + strRes + "\t" + (cmpRes == true ? "same" : "different"));
//			System.out.printf("std: %f\tres: %f\tres.subtract(std): %f\n", std.doubleValue(), res.doubleValue(), Math.abs(res.subtract(std).doubleValue()));
			if ( difference.doubleValue() < Math.abs(res.subtract(std).doubleValue()) )
			{
				cmpRes	=	false;
//				System.out.println(std.doubleValue() + "\tvs.\t" + res.doubleValue() + "\t" + (cmpRes == true ? "same" : "different"));
			}
				
//			if ( Math.abs(rl.getVariance().subtract(std).doubleValue()) < PRECISION_DIFFERENCE )
			//	show results with ranks sorted by execution time if result is the same as standard
			//	then, save method name with adding rank point to the methodLogMap
			if ( true == cmpRes )
			{
//				System.out.println("[" + rank + "]\t" + rl.getName() + "\t" + rl.getVariance() + "\t" + rl.getTime());
				if ( 6 < Long.toString(rl.getTime()).length() )
					System.out.format("[%d]\t%-45s %15.6f %10d ms\n", rank, rl.getName(), rl.getVariance(), rl.getTime() / 1000 / 1000);
				else
					System.out.format("[%d]\t%-45s %15.6f %10d ns\n", rank, rl.getName(), rl.getVariance(), rl.getTime());
				methodLogMap.get(rl.getName()).addPoint(rank);
				++rank;
			}
			//	show results with [X] mark if result is different from standard
			//	then, save method name with adding current values.length to the list of the methodLogMap
			else
			{
//				System.out.println("[X]\t" + rl.getName() + "\t" + rl.getVariance() + "\t" + rl.getTime());
				if ( 6 < Long.toString(rl.getTime()).length() )
					System.out.format("[X]\t%-45s %15.6f %10d ms\n", rl.getName(), rl.getVariance(), rl.getTime() / 1000 / 1000);
				else
					System.out.format("[X]\t%-45s %15.6f %10d ns\n", rl.getName(), rl.getVariance(), rl.getTime());
				methodLogMap.get(rl.getName()).addWrongResultNumber(values.length);
			}
		}
	}
	/*
long num = ;
long den = ;
// compute num and den
double r = num / den + (double) (num - den * (num / den)) / den ;
	 */
	public static void main(String[] args) throws IOException, IllegalArgumentException, IllegalAccessException, InvocationTargetException
	{
		Map>String, MethodLog<	methodLogMap	=	new HashMap>String, MethodLog<();
		
		if ( args.length == 1 && args[0].equals("data") )
		{
			File	dir	=	new File("/home/hchung/programming/data");
			File[]	files	=	dir.listFiles();
			for ( File f : files )
			{
				long	values[];
				if ( f.isFile() )
				{
					System.out.println("=============================\t" + f.getName());
					BufferedReader	br	=	new BufferedReader(new FileReader(f));
					List>Long<	list	=	new ArrayList>Long<();
					String s;
		
					while ( null != ( s = br.readLine() ) )
					{
						list.add(Long.valueOf(s));
						
					}
					values	=	new long[list.size()];
					int	index	=	0;
					for ( Long l : list )	values[index++]	=	l;
					System.out.println("values.length = " + values.length);
				    br.close();
				}
				else
				{
					continue;
				}
				new TestVariance(values).executeTest(methodLogMap);
			}
		}
		else
		{
			for ( int N = 8 * 1024; N < 128 * 1024 * 1024; N *= 2 )
			{
				new TestVariance(N).executeTest(methodLogMap);
			}
		}
		
		List>MethodLog<	methodLogList	=	new ArrayList>MethodLog<();
		System.out.println();
		for ( Entry>String, MethodLog< item : methodLogMap.entrySet() )
		{
			methodLogList.add(item.getValue());
		}
		
		class MethodLogComparator implements Comparator>MethodLog<
		{
			public int compare(MethodLog ml1, MethodLog ml2)
			{
				int	result	=	ml1.getWrongResultNumberList().size() - ml2.getWrongResultNumberList().size();;
				if ( 0 != result )
					return	0 < result ? 1 : -1;
				
				return	ml1.getPoint() - ml2.getPoint();
			}
		}
		Collections.sort(methodLogList, new MethodLogComparator());
		System.out.println();
		//	show the final result sorted by
		//		1.	least wrong result number list
		//		2.	least execution time
		for ( MethodLog ml : methodLogList )
		{
			System.out.format("%-45s %10d %10d\n", ml.getName(), ml.getPoint(), ml.getWrongResultNumberList().size());
		}		
	}
}
/**
 * Immutable record of one algorithm run: the method name, its result and its execution time.
 */
class ResultLog
{
	private final String	name;
	private final BigDecimal	variance;
	/** execution time in nanoseconds */
	private final long	time;

	/**
	 * @param variance result of a double-based method; must be finite, because
	 *                 {@code new BigDecimal(double)} throws NumberFormatException for NaN/Infinity
	 */
	public ResultLog(String name, double variance, long time)
	{
		this(name, new BigDecimal(variance), time);
	}

	/**
	 * @param variance result of a BigDecimal-based method (stored as is)
	 */
	public ResultLog(String name, BigDecimal variance, long time)
	{
		this.name	=	name;
		this.variance	=	variance;
		this.time	=	time;
	}

	public String getName()	{	return	this.name;	}
	public BigDecimal getVariance()	{	return	this.variance;	}
	public long getTime()	{	return	this.time;	}
}
//private static class ValueComparator 
//	implements Comparator<K>{
//	private Map<K, V> map;
//	ValueComparator(Map<K, V> map) {
//		this.map = map;
//	}
//	public int compare(K o1, K o2) {
//		int p = map.get(o1).compareTo(map.get(o2));
//		if (p != 0) {
//			return p;
//		}
//		return o1.compareTo(o2);
//	}
//}
//class MethodLogComparator<K extends Comparable<K>, V extends Comparable<V>> implements Comparator<V>
//{
//	private Map<K, V>	map;
//	MethodLogComparator(Map<K, V> map) { this.map = map; }
//	public int compare(V o1, V o2) {
//		MethodLog	m1	=	(MethodLog) o1;
//		MethodLog	m2	=	(MethodLog) o2;
//		long	result	=	m1.getWrongResultNumberList().size() - m2.getWrongResultNumberList().size();
//		return	0 < result ? 1 : (result < 0 ? -1 : 0);
//	}
//}
/**
 * Statistics for one algorithm accumulated over all data sets: the sum of its speed ranks
 * (lower means faster) and the data-set sizes for which its result was wrong.
 */
class MethodLog
{
	private final String	name;
	/** sum of the ranks received on data sets where the result was correct */
	private int	point;
	/** sizes (values.length) of the data sets where the result was wrong */
	private final List<Integer>	wrongResultNumberList;

	public MethodLog(String name)
	{
		this.name	=	name;
		this.point	=	0;
		this.wrongResultNumberList	=	new ArrayList<Integer>();
	}

	public String getName()	{	return	this.name;	}
	/** Adds {@code point} to the accumulated rank sum. */
	public void addPoint(int point)	{	this.point	+=	point;	}
	public int getPoint()	{	return	this.point;	}
	/** Records that the result was wrong for a data set of {@code num} values. */
	public void addWrongResultNumber(int num)	{	wrongResultNumberList.add(num);	}
	/** Returns a read-only view of the recorded wrong-result data-set sizes. */
	public List<Integer> getWrongResultNumberList()	{	return	Collections.unmodifiableList(wrongResultNumberList);	}
}
* Eclipse에서는 stdCandidateMap.increment()가, 터미널(javac)에서는 stdCandidateMap.add()가 정상 컴파일되는 현상이 발생함. 두 환경이 서로 다른 버전의 fastutil을 classpath에 두고 있을 가능성이 있으므로, 각 환경의 fastutil 버전 확인 및 test 필요.
$ javac -cp /home/hchung/programming/fastutil-6.2.2/fastutil-6.2.2.jar:. TestVariance.java -Xlint
$ java -cp /home/hchung/programming/fastutil-6.2.2/fastutil-6.2.2.jar:. TestVariance
Comments