/*
	Aproksymacja pierwiastka liczb zmiennoprzecinkowych

 Liczba zmiennoprzecinkowa jest postaci M*2^k, gdzie 1.0 <= M < 2.0, 
 k = -127..128. W zależności od k mamy:

 * k > 0 i k nieparzyste

      x = M*2^(2*n+1) -> sqrt(x) = sqrt(M*2^1)*2^n

 * k < 0 i k nieparzyste
	
      x = M*2^(2*n-1) -> sqrt(x) = sqrt(M*2^-1)*2^n
 
 * k parzyste, dowolnego znaku

      x = M*2^(2n) -> sqrt(x) = sqrt(M)*2^n
 
 Na początku tworzona jest tablica (podzielona na trzy części), adresowana 
 m starszymi bitami mantysy, w której zapisywane są pierwiastki z M*2, M, 
 oraz M/2.

 Przy aproksymacji pobierane jest m starszych bitów mantysy, które stają się
 indeksem, wykładnik wybiera którą część tablicy adresujemy. Następnie do
 wykładnika wyniku dodawana jest 1/2 wykładnika pierwiastkowanej liczby.

 Dokładność, gdy brane jest pod uwagę 13 bitów (tak jak w kodzie), jest nie
 gorsza niż 0.05%; przy 10 bitach ok. 0.5%. Więc dla zastosowań graficznych
 może być wystarczająca. W podobny sposób działają rozkazy aproksymujące 
 w SSE2.

 Wojciech Muła
 7.05.2004 12:38:45

 $Date: 2007-06-20 23:14:24 $, $Revision: 1.3 $

*/

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

typedef unsigned int dword;

#define bias 127
typedef union {
	float value;
	struct {
		dword M:23;
		dword E:8;
		dword S:1;
	} fields;
} FLOAT;

#define bits 13			// dokładność (w bitach)
#define size (1l << bits)	// rozmiar 1/3 tablicy
float lookup_sqrt[size*3];	// tablica

// przygotowanie tablic
void prepare_lookup_sqrt() {
	FLOAT f;
	int i;

	f.value = 1.0;
	for (i=0; i<size; i++) {
		f.fields.M = i << (23-bits);
		lookup_sqrt[i + 0*size] = sqrt(f.value);
		lookup_sqrt[i + 1*size] = sqrt(f.value * 2.0);
		lookup_sqrt[i + 2*size] = sqrt(f.value * 0.5);
	}
}

float approx_sqrt(float value) {
	FLOAT f;
	int e;
	int i;

	f.value = value;

	if (f.fields.S) {
		// NaN
		f.fields.S = 0;
		f.fields.E = 0xff;
		f.fields.M = 0x1;
		return f.value;
	}
	
	e = f.fields.E - bias;
	i = f.fields.M >> (23-bits);

	if (e % 2 == +1)
		i += size;
	else
	if (e % 2 == -1)
		i += 2*size;
		
	f.value = lookup_sqrt[i];
	f.fields.E += e/2;

	return f.value;
}


float approx_sqrt2(float value) {
	FLOAT f;
	float t;
	int e;
	int i, fract, index;

	f.value = value;

	if (f.fields.S) {
		// NaN
		f.fields.S = 0;
		f.fields.E = 0xff;
		f.fields.M = 0x1;
		return f.value;
	}
	
	e = f.fields.E - bias;
	i = index = f.fields.M >> (23-bits);

	if (e % 2 == +1)
		i += size;
	else
	if (e % 2 == -1)
		i += 2*size;
		

	if (index == size-1) {
		f.value = lookup_sqrt[i];
		f.fields.E += e/2;
		return f.value;
	}
	else {
#define mask ((1 << bits)-1)
		fract = (f.fields.M & mask);
		if (fract < mask/3)
			f.value = 0.25*lookup_sqrt[i] + 0.75*lookup_sqrt[i+1];
		else
		if (fract < 2*mask/3)
			f.value = 0.5*lookup_sqrt[i] + 0.5*lookup_sqrt[i+1];
		else
			f.value = 0.75*lookup_sqrt[i] + 0.25*lookup_sqrt[i+1];
		
		f.fields.E += e/2;
		return f.value;
#undef mask
	}
}


float approx_sqrt3(float value) {
	FLOAT f;
	float t;
	int e;
	int index, i;

	f.value = value;

	if (f.fields.S) {
		// NaN
		f.fields.S = 0;
		f.fields.E = 0xff;
		f.fields.M = 0x1;
		return f.value;
	}
	
	e = f.fields.E - bias;
	i = index = f.fields.M >> (23-bits);

	if (e % 2 == +1)
		i += size;
	else
	if (e % 2 == -1)
		i += 2*size;
		

	if (index == size-1) {
		f.value = lookup_sqrt[i];
		f.fields.E += e/2;
		return f.value;
	}
	else {
#define mask ((1 << bits)-1)
		t = (f.fields.M & mask) / (float)(mask);
		f.value = t*lookup_sqrt[i] + (1-t)*lookup_sqrt[i+1];
		f.fields.E += e/2;
		return f.value;
#undef mask
	}
}

int main(int argc, char* argv[]) {
	int i = 0;
	int n = 1000000;
	float f, s, a1, a2, a3, e1, e2, e3;

	float min_1 = 1e5, min_2 = 1e5, min_3 = 1e5;
	float max_1 = -1e5, max_2 = -1e5, max_3 = -1e5;
	float avg_1 = 0.0, avg_2 = 0.0, avg_3 = 0.0;
	float range = 100000.0;

	if (argc > 1) {
		range = atof(argv[1]);
		if (range < 1.0) range = 1.0;
	}
	
	prepare_lookup_sqrt();

	srand(time(NULL));
	
	while (i<n) {
		f = (range * rand()/(RAND_MAX+1.0));
		s = sqrt(f);
		a1 = approx_sqrt(f);
		a2 = approx_sqrt2(f);
		a3 = approx_sqrt3(f);
		e1 = 100*(a1-s)/s;
		e2 = 100*(a2-s)/s;
		e3 = 100*(a3-s)/s;

		avg_1 += fabs(e1);
		if (e1 < min_1) min_1 = e1;
		if (e1 > max_1) max_1 = e1;
		
		avg_2 += fabs(e2);
		if (e2 < min_2) min_2 = e2;
		if (e2 > max_2) max_2 = e2;
		
		avg_3 += fabs(e3);
		if (e3 < min_3) min_3 = e3;
		if (e3 > max_3) max_3 = e3;

		//printf("%f %f %f %f\n", s, a1, a2, a3);
		i++;
	}

	printf("f.1 -- min=%f, max=%f, avg=%f\n", min_1, max_1, avg_1/n);
	printf("f.2 -- min=%f, max=%f, avg=%f\n", min_2, max_2, avg_2/n);
	printf("f.3 -- min=%f, max=%f, avg=%f\n", min_3, max_3, avg_3/n);
	return 0;
}
/*
vim ts:8
*/

