#include <float.h>
#include <math.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <Mapi.h>
#include <values.h>

/* Expands to "x, strlen(x)": a string followed by its length, for calls such
 * as mapi_quote(cs(word)). The argument is evaluated twice, so it must not
 * have side effects. */
#define cs(x) (x), strlen(x)

/* Growth step, in bytes, for incrementally built query strings. */
enum { INCREASE = 128 };

/* Compile-time switches: DEBUG traces every probability term, DEBUGSQL prints
 * each generated SQL query before it is executed. */
//#define DEBUG 1
#define DEBUGSQL 1

/* Row-major table of unsigned ints: `x` columns by `y` rows, so cell
 * (row, col) lives at array[row * x + col]. */
typedef struct my_int_arr {
    unsigned int *array;    /* x * y cells, heap-allocated */
    unsigned int x, y;      /* columns, rows */
} my_int_arr_t;

/* Row-major table of strings with the same layout; every cell is a
 * separately allocated string owned by the table. */
typedef struct my_char_arr {
    char **array;           /* x * y string pointers */
    unsigned int x, y;      /* columns, rows */
} my_char_arr_t;


/* Tag-sequence query builders: generate is used for a positive n, generateP
 * for a negative n (see main). Both return a malloc'd SQL string (or NULL). */
char *generate(unsigned int n, unsigned int argc, char *argv[]);
char *generateP(unsigned int n, unsigned int argc, char *argv[]);

/* Run a query and load its result set as a row-major table. */
char **make_char_table(char *query, unsigned int *fields, unsigned int *rows);
unsigned int *make_int_table(char *query, unsigned int *fields, unsigned int *rows);

/* Release the storage held by a table. */
void free_int_array(my_int_arr_t *input);
void free_char_array(my_char_arr_t *input);


/* Appends printf-style text to the heap-allocated, NUL-terminated string
 * *str, growing the allocation exactly as far as needed.
 *
 * On failure (encoding error or out of memory) *str is freed and set to NULL;
 * every later call on a NULL *str is a no-op, so a query builder only has to
 * check its result once, after the last append. */
static void sql_appendf(char **str, const char *fmt, ...) {
    va_list ap;
    int extra;
    size_t used;
    char *grown;

    if (*str == NULL)
        return;

    /* First pass: measure how many characters the formatted text needs. */
    va_start(ap, fmt);
    extra = vsnprintf(NULL, 0, fmt, ap);
    va_end(ap);
    if (extra < 0)
        goto fail;

    used = strlen(*str);
    grown = realloc(*str, used + (size_t) extra + 1);
    if (grown == NULL)
        goto fail;              /* *str is still valid and released below */
    *str = grown;

    /* Second pass: write the text behind what is already there. */
    va_start(ap, fmt);
    vsnprintf(*str + used, (size_t) extra + 1, fmt, ap);
    va_end(ap);
    return;

fail:
    free(*str);
    *str = NULL;
}

/* Builds a single-row SELECT that counts, for every word argv[0..len-1], how
 * often it occurs in the `source` table; the column for word i is w<i+1>.
 * Words are pasted between single quotes, so they must already be
 * SQL-escaped (main escapes them with mapi_quote).
 * Returns a malloc'd query the caller must free, or NULL when len is 0 or
 * memory runs out. */
char * evaluatewords(unsigned int len, char *argv[]) {
    char *select;
    unsigned int i;

    if (len == 0 || argv == NULL)
        return NULL;

    select = calloc(1, 1);      /* start from the empty string */
    sql_appendf(&select, "SELECT\n");
    for (i = 0; i < len; i++) {
        sql_appendf(&select, "%s(SELECT COUNT(*) FROM source WHERE word = '%s') AS w%u",
                    i == 0 ? "" : ",\n", argv[i], i + 1);
    }
    sql_appendf(&select, ";\n");

    return select;
}

/* Builds a single-row SELECT that looks up, for every candidate tag row
 * row < count_y and every word col < argc, the gram1 `subtotal` of
 * (argv[col], pos[row * argc + col]). Columns are numbered wt1, wt2, ... in
 * row-major order, so the result reshapes into count_y rows of argc values.
 * Returns a malloc'd query, or NULL for empty input / out of memory. */
char * evaluatewordspos(unsigned int argc, char *argv[], unsigned int count_y, char *pos[]) {
    char *select;
    unsigned int row, col;

    if (argc == 0 || count_y == 0 || argv == NULL || pos == NULL)
        return NULL;

    select = calloc(1, 1);
    sql_appendf(&select, "SELECT\n");
    for (row = 0; row < count_y; row++) {
        for (col = 0; col < argc; col++) {
            unsigned int index = row * argc + col;
            sql_appendf(&select, "%s(SELECT subtotal FROM gram1 WHERE word1 = '%s' AND pos1 = '%s') AS wt%u",
                        index == 0 ? "" : ",\n", argv[col], pos[index], index + 1);
        }
    }
    sql_appendf(&select, ";\n");

    return select;
}

/* Builds a single-row SELECT that counts, for every tag row row < count_y and
 * every window start j <= count_x - n, how often the n tags
 * pos[row*count_x + j .. row*count_x + j + n - 1] occur as one gram<n> row.
 * Columns are numbered wt1, wt2, ... row-major: count_y rows of
 * count_x - n + 1 windows.
 * Returns a malloc'd query, or NULL for invalid sizes / out of memory. */
char * evaluateposseq(unsigned int n, unsigned int count_x, unsigned int count_y, char *pos[]) {
    char *select;
    unsigned int windows, row, j, k;

    if (n == 0 || n > count_x || count_y == 0 || pos == NULL)
        return NULL;
    windows = count_x - (n - 1);

    select = calloc(1, 1);
    sql_appendf(&select, "SELECT\n");
    for (row = 0; row < count_y; row++) {
        unsigned int base = row * count_x;

        for (j = 0; j < windows; j++) {
            sql_appendf(&select, "%s(SELECT COUNT(*) FROM gram%u WHERE",
                        (row == 0 && j == 0) ? "" : ",\n", n);
            for (k = 0; k < n; k++) {
                sql_appendf(&select, "%s pos%u = '%s'",
                            k == 0 ? "" : " AND", k + 1, pos[base + j + k]);
            }
            sql_appendf(&select, ") AS wt%u", row * windows + j + 1);
        }
    }
    sql_appendf(&select, ";\n");

    return select;
}

/* Builds a single-row SELECT that, for every window used by evaluateposseq,
 * counts the gram<n> rows whose n'th tag (pos<n>) equals the n'th tag of that
 * window, pos[row*count_x + j + n - 1]. Column numbering matches
 * evaluateposseq (wt1, wt2, ...; count_y rows of count_x - n + 1 windows).
 * Returns a malloc'd query, or NULL for invalid sizes / out of memory. */
char * evaluateposmargin(unsigned int n, unsigned int count_x, unsigned int count_y, char *pos[]) {
    char *select;
    unsigned int windows, row, j;

    if (n == 0 || n > count_x || count_y == 0 || pos == NULL)
        return NULL;
    windows = count_x - (n - 1);

    select = calloc(1, 1);
    sql_appendf(&select, "SELECT\n");
    for (row = 0; row < count_y; row++) {
        /* The n'th tag of window j sits at offset j + n - 1 within the row;
         * this always stays inside the row, also for n == 1. */
        unsigned int nth = row * count_x + (n - 1);

        for (j = 0; j < windows; j++) {
            sql_appendf(&select, "%s(SELECT COUNT(*) FROM gram%u WHERE pos%u = '%s') AS wt%u",
                        (row == 0 && j == 0) ? "" : ",\n", n, n, pos[nth + j],
                        row * windows + j + 1);
        }
    }
    sql_appendf(&select, ";\n");

    return select;
}

double *finalprobabilities(unsigned int n, unsigned int count_x, unsigned int *count, unsigned int posword_y, unsigned int *posword, unsigned int *posseq, unsigned int *posmargin, unsigned int *highest) {
    int i;
    int maxj = count_x - (n - 1);

    double *output = (double *) malloc(sizeof(double) * posword_y);

    *highest = 0;

    if (output) { 
        double highestp = -DBL_MAX;
        for (i = 0; i < posword_y; i++ ) {
            int j;
            int index_tt = maxj * i;
            int index_wt = i * count_x;
            double prob = 1.0;
        
            prob += log(posword[index_wt]) - log(count[0]);
            #ifdef DEBUG
            printf("w|t %d/%d * ", posword[index_wt], count[0]);
            #endif

            for (j = 0; j < maxj; j++ ) {
                prob += log(posword[index_wt + j + 1]) - log(count[j + 1]);
                #ifdef DEBUG
                printf("w|t %d/%d * ", posword[index_wt + j + 1], count[j + 1]);
                #endif
                
                prob += log(posseq[index_tt + j]) - log(posmargin[index_tt + j]);
                #ifdef DEBUG
                printf("t|t %d/%d * ", posseq[index_tt + j], posmargin[index_tt + j]);
                #endif
            }
            #ifdef DEBUG
            printf("\n= log(%.10f) -> %.10f\n\n", prob, pow(M_El, prob));
            #endif

            output[i] = prob;
            
            #ifdef DEBUG
            printf("%f < %f\n", highestp, prob);
            #endif

            if (highestp < prob) {
                highestp = prob;
                *highest = i;
            }
        }
    }

    return output;
}

/* Prints the command-line usage on stdout. argv[0] is used as the program
 * name when available; a NULL argv or argv[0] (possible when argc == 0) is
 * tolerated. */
void utilize(char *argv[]) {
    const char *prog = (argv != NULL && argv[0] != NULL) ? argv[0] : "<program>";

    printf("Run with %s << n-gram >> [[sentence]]\n\n"
           "If a positive number is used, a fast but sparse-data prone method is chosen; "
           "a negative number will assume independence between words and pos tags.\n"
           "The current program has a limitation: the sentence length should be equal to or bigger than n.\n",
           prog);
}


/* Runs one of the single-row count queries built by the evaluate* helpers
 * and reshapes its row into a rows x cols table stored in *out.
 * Takes ownership of `query` (frees it; NULL means generation failed).
 * Returns 0 on success; otherwise reports `what` on stderr and returns -1
 * (out->array may still be set and must be released with free_int_array). */
static int fetch_count_table(char *query, const char *what, unsigned int rows, unsigned int cols, my_int_arr_t *out) {
    if (query == NULL) {
        fprintf(stderr, "Unexpected error at %s query generation\n", what);
        return -1;
    }
#ifdef DEBUGSQL
    printf("%s\n", query);
#endif
    out->array = make_int_table(query, &out->x, &out->y);
    free(query);

    /* The query yields exactly one row holding rows * cols counts. */
    if (out->array == NULL || out->y != 1 || out->x != rows * cols) {
        fprintf(stderr, "Unexpected result for the %s query\n", what);
        return -1;
    }
    out->y = rows;
    out->x = cols;
    return 0;
}

/* Entry point: tags the sentence given on the command line.
 *
 * Usage: <program> n word1 word2 ... (see utilize). Steps:
 *   1. count every input word in `source`;
 *   2. query candidate tag sequences (generate for n > 0, generateP for n < 0);
 *   3. fetch word/tag, tag n-gram and margin counts for every candidate;
 *   4. score the candidates with finalprobabilities and print them all,
 *      marking the best one with '*'.
 * Returns EXIT_SUCCESS, or EXIT_FAILURE on bad arguments or any error. */
int main(int argc, char *argv[]) {
    int status = EXIT_FAILURE;
    unsigned int i;                 /* just our usual friend the index */
    char *query = NULL;             /* we will store queries here */
    char *end = NULL;
    long P;
    unsigned long n_abs;
    unsigned int n, words, windows;
    unsigned int highest = 0;
    double *final = NULL;
    double forward = 0.0;
    my_char_arr_t input = { NULL, 0, 0 };
    my_char_arr_t pos = { NULL, 0, 0 };
    my_int_arr_t count = { NULL, 0, 0 };
    my_int_arr_t posword = { NULL, 0, 0 };
    my_int_arr_t posmargin = { NULL, 0, 0 };
    my_int_arr_t posseq = { NULL, 0, 0 };

    /* We expect at least 1 + n input arguments. */
    if (argc <= 2) {
        utilize(argv);
        return EXIT_FAILURE;
    }

    /* The first argument gives us n for the n-grams. A positive number
     * selects the fast method (generate), which assumes dependence between
     * consecutive words; a negative number selects the slow method
     * (generateP), which uses an independent unigram table to look up the
     * words. */
    P = strtol(argv[1], &end, 10);
    if (end == argv[1] || *end != '\0' || P == 0) {
        /* first argument should be a non-zero number */
        utilize(argv);
        return EXIT_FAILURE;
    }

    /* Normalise n so we can use it in the rest of the program. Negating in
     * unsigned arithmetic keeps the most negative long well-defined. */
    n_abs = (P < 0) ? 0UL - (unsigned long) P : (unsigned long) P;
    words = (unsigned int) (argc - 2);
    if (n_abs > words) {
        /* the sentence must contain at least n words */
        utilize(argv);
        return EXIT_FAILURE;
    }
    n = (unsigned int) n_abs;
    windows = words - (n - 1);      /* n-gram windows per tag sequence */

    /* We quote the input, so we can use it in any further function.
     * calloc keeps unfilled slots NULL, so cleanup is safe at any point. */
    input.array = calloc(words, sizeof *input.array);
    if (input.array == NULL) {
        fprintf(stderr, "Out of memory\n");
        goto cleanup;
    }
    input.x = words;
    input.y = 1;
    for (i = 0; i < input.x; i++) {
        input.array[i] = mapi_quote(cs(argv[i + 2]));
        if (input.array[i] == NULL) {
            fprintf(stderr, "Out of memory\n");
            goto cleanup;
        }
    }

    /* First we find out how often each word occurs in the source dataset. As
     * part of the P(w|t) calculation it tells us whether the word exists at
     * all, and therefore whether smoothing could be needed. */
    query = evaluatewords(input.x, input.array);
    if (query == NULL) {
        fprintf(stderr, "Unexpected error at word query generation\n");
        goto cleanup;
    }
#ifdef DEBUGSQL
    printf("%s\n", query);
#endif
    count.array = make_int_table(query, &count.x, &count.y);
    free(query);
    query = NULL;
    if (count.array == NULL || count.y != 1 || count.x != input.x) {
        printf("Something went wrong!\n");
        goto cleanup;
    }

    /* TODO: map zero results to NULL.
     *
     * Possible approaches:
     *   * leaving the whole comparison out can produce somewhat exponential
     *     results
     *   * map them to an *UNKNOWN* marker and use that as the smoothing value
     */

    /* With the prior knowledge we can now search for a tag sequence that
     * could fit our source data. */
    query = (P > 0) ? generate(n, input.x, input.array)
                    : generateP(n, input.x, input.array);
    if (query == NULL) {
        fprintf(stderr, "Unexpected error at PoS query generation\n");
        goto cleanup;
    }
#ifdef DEBUGSQL
    printf("%s\n", query);
#endif
    pos.array = make_char_table(query, &pos.x, &pos.y);
    free(query);
    query = NULL;
    if (pos.array == NULL || pos.y == 0) {
        fprintf(stderr, "No results were found :(\n");
        goto cleanup;
    }
    if (pos.x != input.x) {
        fprintf(stderr, "Unexpected PoS table with %u columns for %u words\n", pos.x, input.x);
        goto cleanup;
    }

    /* Now find out the amount of words from our source that match a certain
     * pos tag. TODO: handle zero! */
    if (fetch_count_table(evaluatewordspos(input.x, input.array, pos.y, pos.array),
                          "word/PoS", pos.y, input.x, &posword) != 0)
        goto cleanup;

    /* Find the amount of PoS tags that match the n'th pos tag in a sequence. */
    if (fetch_count_table(evaluateposmargin(n, pos.x, pos.y, pos.array),
                          "PoS margin", pos.y, windows, &posmargin) != 0)
        goto cleanup;

    /* Find the amount of PoS sequences we have observed in our previous steps. */
    if (fetch_count_table(evaluateposseq(n, pos.x, pos.y, pos.array),
                          "PoS sequence", pos.y, windows, &posseq) != 0)
        goto cleanup;

    final = finalprobabilities(n, input.x, count.array, posword.y, posword.array,
                               posseq.array, posmargin.array, &highest);
    if (final == NULL) {
        fprintf(stderr, "Could not compute the probabilities\n");
        goto cleanup;
    }

    for (i = 0; i < pos.y; i++) {
        unsigned int j;
        double linear = exp(final[i]);      /* scores are log-probabilities */

        printf("%c ", (i == highest ? '*' : ' '));
        for (j = 0; j < pos.x; j++) {
            printf("%s ", pos.array[i * pos.x + j]);
        }
        forward += linear;
        printf("\t\tlog(%f)\t=\t%f\n", final[i], linear);
    }

    printf("\t\t\t\t\tsum:\t%f\n", forward);
    status = EXIT_SUCCESS;

cleanup:
    free(query);
    free(final);
    free_int_array(&posseq);
    free_int_array(&posmargin);
    free_int_array(&posword);
    free_int_array(&count);
    /* A table whose array is NULL owns no strings. */
    if (pos.array != NULL)
        free_char_array(&pos);
    if (input.array != NULL)
        free_char_array(&input);
    return status;
}


/* Appends printf-style text to the heap-allocated, NUL-terminated string
 * *str, growing the allocation exactly as far as needed.
 *
 * On failure (encoding error or out of memory) *str is freed and set to NULL;
 * every later call on a NULL *str is a no-op, so the generators only check
 * their buffers once, before assembling the final query. */
static void gen_appendf(char **str, const char *fmt, ...) {
    va_list ap;
    int extra;
    size_t used;
    char *grown;

    if (*str == NULL)
        return;

    /* First pass: measure how many characters the formatted text needs. */
    va_start(ap, fmt);
    extra = vsnprintf(NULL, 0, fmt, ap);
    va_end(ap);
    if (extra < 0)
        goto fail;

    used = strlen(*str);
    grown = realloc(*str, used + (size_t) extra + 1);
    if (grown == NULL)
        goto fail;              /* *str is still valid and released below */
    *str = grown;

    /* Second pass: write the text behind what is already there. */
    va_start(ap, fmt);
    vsnprintf(*str + used, (size_t) extra + 1, fmt, ap);
    va_end(ap);
    return;

fail:
    free(*str);
    *str = NULL;
}

/* Builds the tag-sequence query for the fast method (positive n).
 *
 * The sentence argv[0..argc-1] is covered by argc - n + 1 overlapping
 * windows. Window k (1-based) is matched against gram<n> instance s<k> on
 * word1..word<n>, and consecutive instances must agree on their overlapping
 * tags (s<k>.pos<j> = s<k-1>.pos<j+1>). The SELECT returns one tag per word in
 * columns p1..p<argc>. Words are pasted between single quotes and must
 * already be SQL-escaped.
 *
 * Returns a malloc'd query, or NULL when n is 0, argc < n, or memory runs
 * out. */
char * generate(unsigned int n, unsigned int argc, char *argv[]) {
    char *select, *from, *where, *output = NULL;
    unsigned int i, j, last;

    if (n == 0 || argc < n || argv == NULL)
        return NULL;

    select = calloc(1, 1);      /* empty, growable strings */
    from = calloc(1, 1);
    where = calloc(1, 1);

    gen_appendf(&select, "SELECT\ns1.pos1 AS p1");
    gen_appendf(&from, "FROM\ngram%u AS s1", n);

    /* s1 covers the first n words. */
    gen_appendf(&where, "WHERE\ns1.word1 = '%s'", argv[0]);
    for (i = 1; i < n; i++)
        gen_appendf(&where, " AND s1.word%u = '%s'", i + 1, argv[i]);

    /* Every further instance s<i> slides one word to the right and must agree
     * with s<i-1> on the n - 1 tags they share. */
    for (i = 2; i < argc + 2 - n; i++) {
        gen_appendf(&select, ",\ns%u.pos1 AS p%u", i, i);
        gen_appendf(&from, ",\ngram%u AS s%u", n, i);
        for (j = 1; j < n; j++) {
            gen_appendf(&where, " AND\ns%u.word%u = '%s' AND s%u.pos%u = s%u.pos%u",
                        i, j, argv[i + j - 2], i, j, i - 1, j + 1);
        }
        gen_appendf(&where, " AND s%u.word%u = '%s'", i, n, argv[i + n - 2]);
    }

    /* The last instance also provides the tags of the final n - 1 words. */
    last = argc - (n - 1);      /* >= 1 because argc >= n */
    for (i = 2; i <= n; i++)
        gen_appendf(&select, ",\ns%u.pos%u AS p%u", last, i, last + (i - 1));

    if (select != NULL && from != NULL && where != NULL) {
        output = calloc(1, 1);
        gen_appendf(&output, "%s\n%s\n%s;\n", select, from, where);
    }

    free(select);
    free(from);
    free(where);

    return output;
}

/* Builds the tag-sequence query for the slow method (negative n).
 *
 * Uses argc - n + 1 instances s1.. of gram<n>p. Each word argv[i] restricts
 * the tag at its position to the tags gram1 records for that word, and
 * consecutive instances must agree on their overlapping tags
 * (s<k>.pos<j+1> = s<k+1>.pos<j>). The last instance supplies the tags of the
 * final n - 1 words; columns t1..t<argc> hold one tag per word.
 * Join conditions are emitted only between instances that exist: the last
 * instance has no successor.
 *
 * Returns a malloc'd query, or NULL when n is 0, n > argc, or memory runs
 * out. */
char * generateP(unsigned int n, unsigned int argc, char *argv[]) {
    char *select, *from, *where, *output = NULL;
    unsigned int i, j, tables;

    if (n == 0 || n > argc || argv == NULL)
        return NULL;
    tables = argc - (n - 1);

    select = calloc(1, 1);
    from = calloc(1, 1);
    where = calloc(1, 1);

    for (i = 0; i < tables; i++) {
        gen_appendf(&select, ", s%u.pos1 AS t%u", i + 1, i + 1);
        gen_appendf(&from, ", gram%up AS s%u", n, i + 1);
    }

    for (i = 0; i < tables; i++) {
        gen_appendf(&where, " AND s%u.pos1 IN (SELECT pos1 FROM gram1 WHERE word1 = '%s')",
                    i + 1, argv[i]);
        if (i + 1 < tables) {
            for (j = 1; j < n; j++)
                gen_appendf(&where, " AND s%u.pos%u = s%u.pos%u", i + 1, j + 1, i + 2, j);
            gen_appendf(&where, "\n");
        }
    }

    for (i = 1; i < n; i++) {
        gen_appendf(&select, ", s%u.pos%u AS t%u", tables, i + 1, argc - n + i + 1);
        gen_appendf(&where, " AND s%u.pos%u IN (SELECT pos1 FROM gram1 WHERE word1 = '%s')",
                    tables, i + 1, argv[argc - n + i]);
    }

    /* Skip the leading ", " of both lists and the leading " AND" of the
     * conditions; each holds at least one entry because tables >= 1. */
    if (select != NULL && from != NULL && where != NULL) {
        output = calloc(1, 1);
        gen_appendf(&output, "SELECT DISTINCT \n%s\nFROM\n%s\nWHERE\n%s\n",
                    select + 2, from + 2, where + 4);
    }

    free(select);
    free(from);
    free(where);

    return output;
}




/* ------------------------------------- UTILITY FUNCTIONS ------------------------------------- */

/* cleans up the mess from an int_array: frees its cell storage (not the
 * struct itself) and resets it to an empty table, so a repeated call is
 * harmless. A NULL argument is ignored. */
void free_int_array(my_int_arr_t *input) {
    if (input == NULL)
        return;
    free(input->array);
    input->array = NULL;
    input->x = 0;
    input->y = 0;
}

/* cleans up the mess from a char_array: frees each of its x * y strings and
 * then the pointer array (not the struct itself), and resets it to an empty
 * table. A table whose array is NULL, e.g. after a failed load, is only
 * reset; a NULL argument is ignored. */
void free_char_array(my_char_arr_t *input) {
    size_t i, cells;

    if (input == NULL)
        return;

    if (input->array != NULL) {
        cells = (size_t) input->x * input->y;
        for (i = 0; i < cells; i++)
            free(input->array[i]);
        free(input->array);
    }

    input->array = NULL;
    input->x = 0;
    input->y = 0;
}

/* Connects to the local MonetDB server and executes `query`.
 *
 * On success returns the result handle, leaves the open connection in
 * *connection, stores the result's column count in *fields and row count in
 * *rows, and prints the "Result table:" header. On a connection or query
 * error, or an empty result, the problem is reported, every MonetDB resource
 * is released, *connection is NULL, *fields and *rows are 0, and NULL is
 * returned. */
static MapiHdl open_result_table(const char *query, Mapi *connection, unsigned int *fields, unsigned int *rows) {
    MapiHdl hdl;
    long long row_count;
    int field_count;

    *connection = NULL;
    *fields = 0;
    *rows = 0;

    /* Log into the database */
    Mapi conn = mapi_connect("localhost", 0, "monetdb", "monetdb", "sql", NULL);
    if (conn == NULL)
        return NULL;
    if (mapi_error(conn)) {
        mapi_explain(conn, stderr);
        mapi_destroy(conn);
        return NULL;
    }

    hdl = mapi_query(conn, query);
    if (hdl == NULL || mapi_error(conn)) {
        mapi_explain(conn, stderr);
        if (hdl != NULL)
            mapi_close_handle(hdl);
        mapi_destroy(conn);
        return NULL;
    }

    row_count = mapi_get_row_count(hdl);
    field_count = mapi_get_field_count(hdl);
    if (row_count <= 0 || field_count <= 0) {
        printf("*NO* results\n");
        mapi_close_handle(hdl);
        mapi_destroy(conn);
        return NULL;
    }

    printf("Result table:\n");
    *connection = conn;
    *fields = (unsigned int) field_count;
    *rows = (unsigned int) row_count;
    return hdl;
}

/* Releases a handle and connection obtained from open_result_table. */
static void close_result_table(Mapi connection, MapiHdl hdl) {
    mapi_close_handle(hdl);
    mapi_destroy(connection);
}

/* creates a char table from the MonetDB result set
 *
 * Runs `query` and copies every cell into a row-major array of
 * *rows x *fields strings, each escaped with mapi_quote so it can be pasted
 * between single quotes in a later query. SQL NULL cells become "".
 * Returns the array (release it with free_char_array), or NULL with *fields
 * and *rows set to 0 on any error or when the result is empty. */
char ** make_char_table(char *query, unsigned int *fields, unsigned int *rows) {
    Mapi connection;
    MapiHdl hdl = open_result_table(query, &connection, fields, rows);
    char **output = NULL;
    size_t cells = 0, i;
    unsigned int r, f;

    if (hdl == NULL)
        return NULL;

    /* optimalisation; instead of pos tags parse here the index to them */
    cells = (size_t) *fields * *rows;
    output = calloc(cells, sizeof *output);     /* NULL-filled: easy cleanup */
    if (output == NULL)
        goto fail;

    for (r = 0; r < *rows; r++) {
        if (mapi_fetch_row(hdl) == 0)
            goto fail;                          /* fewer rows than announced */

        for (f = 0; f < *fields; f++) {
            const char *value = mapi_fetch_field(hdl, (int) f);
            char **cell = &output[(size_t) r * *fields + f];

            if (value == NULL)
                value = "";                     /* SQL NULL */
            *cell = mapi_quote(cs(value));
            if (*cell == NULL)
                goto fail;
            printf("%s ", *cell);
        }
        printf("\n");
    }

    close_result_table(connection, hdl);
    printf("\n");
    return output;

fail:
    if (output != NULL) {
        for (i = 0; i < cells; i++)
            free(output[i]);
        free(output);
    }
    close_result_table(connection, hdl);
    *fields = 0;
    *rows = 0;
    return NULL;
}

/* creates an integer array of the monetdb resultset
 *
 * Runs `query` and parses every cell as a base-10 unsigned integer into a
 * row-major array of *rows x *fields values; SQL NULL cells are stored as 0.
 * Returns the malloc'd array, or NULL with *fields and *rows set to 0 on any
 * error or when the result is empty. */
unsigned int * make_int_table(char *query, unsigned int *fields, unsigned int *rows) {
    Mapi connection;
    MapiHdl hdl = open_result_table(query, &connection, fields, rows);
    unsigned int *output = NULL;
    unsigned int r, f;

    if (hdl == NULL)
        return NULL;

    output = malloc(sizeof *output * ((size_t) *fields * *rows));
    if (output == NULL)
        goto fail;

    for (r = 0; r < *rows; r++) {
        if (mapi_fetch_row(hdl) == 0)
            goto fail;                          /* fewer rows than announced */

        for (f = 0; f < *fields; f++) {
            const char *value = mapi_fetch_field(hdl, (int) f);
            unsigned int *cell = &output[(size_t) r * *fields + f];

            *cell = (value != NULL) ? (unsigned int) strtoul(value, NULL, 10) : 0;
            printf("%u ", *cell);
        }
        printf("\n");
    }

    close_result_table(connection, hdl);
    printf("\n");
    return output;

fail:
    free(output);
    close_result_table(connection, hdl);
    *fields = 0;
    *rows = 0;
    return NULL;
}

