#include <stdio.h>
#include <stdlib.h>
#include <math.h>

// "Unreachable" sentinel for the borders of the int DTW cost matrix.
// Defined as an int constant because it is stored into int cells; the
// previous double literal (1e9) relied on an implicit double -> int conversion.
#define INF 1000000000

// Return the smallest of three integers.
// Uses a running minimum so ties are handled correctly: the previous
// strict-comparison chain returned c whenever a == b < c (e.g. min(1, 1, 5)
// gave 5), which corrupted the DTW cumulative costs on ties.
int min(int a, int b, int c)
{
    int smallest = a;
    if (b < smallest)
        smallest = b;
    if (c < smallest)
        smallest = c;
    return smallest;
}

// Local cost of aligning element x of one series with element y of the
// other: the absolute difference |x - y| (the 1-D Euclidean distance).
// Swap this out to use a different per-element metric.
int distance(int x, int y)
{
    int diff = x - y;
    return diff < 0 ? -diff : diff;
}

// Dynamic Time Warping (DTW) between two integer series.
//
// Fills the (n+1) x (m+1) cumulative cost matrix dp, where dp[i][j] is the
// cheapest alignment cost of s1[0..i-1] against s2[0..j-1]. Row 0 and
// column 0 are INF borders and dp[0][0] = 0, so every alignment starts at
// the first elements of both series. The warp path is then backtracked
// from (n, m) to (1, 1).
//
// Output (stdout): the cost matrix, printed with the row for the last
// element of s1 first, followed by the DTW distance dp[n][m] and that
// distance divided by the number of cells on the warp path.
//
// s1, s2 : input series of lengths n and m (not modified).
// n, m   : series lengths; nothing is printed if either is <= 0 or if a
//          series pointer is NULL.
//
// NOTE(review): the matrices hold int and use INF as the border sentinel,
// so cumulative costs are assumed to stay well below INF (and INT_MAX).
void dtw(int *s1, int *s2, int n, int m)
{
    if (s1 == NULL || s2 == NULL || n <= 0 || m <= 0)
        return;

    // Heap-allocate the matrices: two (n+1) x (m+1) int arrays as stack VLAs
    // overflow the stack for long series. The pointer-to-VLA type keeps the
    // natural dp[i][j] indexing.
    int (*dp)[m + 1] = malloc(((size_t)n + 1) * sizeof *dp);
    int (*path)[m + 1] = malloc(((size_t)n + 1) * sizeof *path);
    if (dp == NULL || path == NULL)
    {
        fprintf(stderr, "dtw: out of memory\n");
        free(dp);
        free(path);
        return;
    }

    // Initialise: indices run 0..n and 0..m; every cell except the origin
    // starts out unreachable.
    for (int i = 0; i <= n; i++)
    {
        for (int j = 0; j <= m; j++)
        {
            dp[i][j] = INF;
            path[i][j] = 0;
        }
    }
    dp[0][0] = 0;

    // Fill the DP matrix and record which predecessor each cell came from:
    // 1 = (i-1, j), 2 = (i, j-1), 3 = (i-1, j-1).
    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= m; j++)
        {
            int cost = distance(s1[i-1], s2[j-1]);
            int min_cost = min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]);
            dp[i][j] = cost + min_cost;

            if (dp[i-1][j] == min_cost) path[i][j] = 1;
            else if (dp[i][j-1] == min_cost) path[i][j] = 2;
            else path[i][j] = 3;
        }
    }

    // dp[n][m] is the accumulated cost of the optimal warp path; every step's
    // cost is already included, so nothing is added while backtracking.
    int total_distance = dp[n][m];

    // Backtrack from (n, m) to (1, 1) to count the cells on the warp path.
    int path_length = 0;
    int i = n, j = m;
    while (i > 0 && j > 0)
    {
        path_length++;
        int direction = path[i][j];
        if (direction == 1)
        {
            i--;            // step back in s1 only
        }
        else if (direction == 2)
        {
            j--;            // step back in s2 only
        }
        else
        {
            i--;            // step back in both series
            j--;
        }
    }

    // Print the cost matrix, last element of s1 on the top row.
    for (int row = n; row >= 1; row--)
    {
        for (int col = 1; col <= m; col++)
        {
            printf(" %d", dp[row][col]);
            printf("\t");
        }
        printf("\n");
    }

    // path_length >= 1 here because n > 0 and m > 0.
    float normalised_distance = (float)total_distance / (float)path_length;
    printf("DTW distance: %d\n", total_distance);
    printf("Normalised distance: %f\n", normalised_distance);

    free(dp);
    free(path);
}

int main() {
    // Demo: align two short series of different lengths (5 and 7 samples).
    int s1[] = {7, 1, 2, 5, 9};
    int s2[] = {1, 8, 0, 4, 4, 2, 0};

    // Element counts; valid because s1/s2 are true arrays here, not pointers.
    int len1 = sizeof s1 / sizeof *s1;
    int len2 = sizeof s2 / sizeof *s2;

    printf("Warp Path:\n");
    dtw(s1, s2, len1, len2);
    return 0;
}
