#include "pch.h"
#include "Encoder.h"

#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <stdexcept>
#include <string>

/////////////////////////////////////////////////////////////////////////////

// Prefix comparison used for DICOM keyword matching.
//
// Returns true when the first `len` characters of `other` match `ref`. A
// character of `other` matches if it equals the corresponding character of
// `ref` exactly, or equals it after std::tolower. Only `other` is lowered, so
// write `ref` in lower case to get a fully case-insensitive comparison.
//
// The scan stops at the first mismatch, or at a NUL terminator shared by both
// strings (a full match). Neither string is therefore read past its
// terminator, even when `len` exceeds its length. A non-positive `len` matches
// trivially.
static
bool is_matching(const char* ref, const char* other, int len)
{
    for (int i = 0; i < len; ++i)
    {
        const char expected = ref[i];
        const char actual = other[i];
        // std::tolower requires a value representable as unsigned char.
        // Passing a negative char directly is undefined behaviour.
        const char lowered = static_cast<char>(std::tolower(static_cast<unsigned char>(actual)));

        if (expected != actual && expected != lowered)
            return false;

        // Both strings ended at the same position: nothing more to compare.
        if (expected == '\0')
            break;
    }

    return true;
}

/////////////////////////////////////////////////////////////////////////////

extern "C" uint8_t* OpenJPHWrapper::Compress(const uint8_t* uncompressedData, size_t dataSize, char* photometricInterpretation,
    ojph::ui32 width, ojph::ui32 height, ojph::ui32 bitDepth, ojph::ui32 isSigned, ojph::ui32 planarConfig, char* transferSyntax,
    ojph::ui32 quality, size_t* compressedSize)
{
    char prog_order_store[] = "RPCL";
    char* prog_order = prog_order_store;
    char profile_string_store[] = "";
    char* profile_string = profile_string_store;
    char* com_string = NULL;
    ojph::ui32 num_decompositions = 5;
    float quantization_step = -1.0f;
    bool reversible = true;
    int employ_color_transform = -1;

    const int max_precinct_sizes = 33; //maximum number of decompositions is 32
    ojph::size precinct_size[max_precinct_sizes];
    int num_precincts = -1;

    ojph::codestream codestream;
    ojph::infile_base* base = nullptr;

    ojph::size block_size(64, 64);
    ojph::size dims(width, height);
    ojph::size tile_size(0, 0);
    ojph::point tile_offset(0, 0);

    ojph::point image_offset(0, 0);
    const ojph::ui32 initial_num_comps = 4;
    ojph::ui32 max_num_comps = initial_num_comps;
    ojph::ui32 num_components = 1;
    ojph::ui32 num_is_signed = isSigned;
    ojph::ui32 num_bit_depths = bitDepth;
    ojph::ui32 num_comp_downsamps = 0;

    ojph::point downsampling_store[initial_num_comps];
    ojph::point comp_downsampling(1, 1);

    bool tlm_marker = false;
    bool tileparts_at_resolutions = false;
    bool tileparts_at_components = false;

    if (is_matching("1.2.840.10008.1.2.4.201", transferSyntax, strlen(transferSyntax)) || is_matching("1.2.840.10008.1.2.4.202", transferSyntax, strlen(transferSyntax)))
        reversible = true;
    else if (is_matching("1.2.840.10008.1.2.4.203", transferSyntax, strlen(transferSyntax)))
        reversible = false;
    else
        throw std::runtime_error("Unsupported Transfer Syntax: " + std::string(transferSyntax));

    if (is_matching("mono", photometricInterpretation, 4))  // MONOCHROME1/2
    {
        ojph::mem_infile mem_input;
        mem_input.open(uncompressedData, dataSize);

        ojph::param_siz siz = codestream.access_siz();
        siz.set_image_extent(ojph::point(dims.w, dims.h));
        siz.set_num_components(num_components);
        for (ojph::ui32 c = 0; c < num_components; ++c)
            siz.set_component(c, comp_downsampling, bitDepth, isSigned);

        ojph::param_cod cod = codestream.access_cod();
        cod.set_num_decomposition(num_decompositions);
        cod.set_block_dims(block_size.w, block_size.h);
        cod.set_progression_order(prog_order);
        cod.set_color_transform(false);

        cod.set_reversible(reversible);

        if (!reversible)
        {
            float exponent = 1.5f;
            float compression_scale = 0.1f;

            if (bitDepth == 8)
            {
                exponent = 1.42f;              // Lower the curve, smooth compression rise
                compression_scale = 0.14f;     // Slightly stronger compression baseline
            }
            else if (bitDepth == 16)
            {
                exponent = 1.6f;
                compression_scale = 0.08f;
            }
            else if (bitDepth = 32)
            {
                // not calculated yet
            }

            float base_step = 1.0f / powf(2.0f, (float)(bitDepth - 1));
            float quantization_step = base_step * powf((float)quality, exponent) * compression_scale;

            // Clamp for safety
            if (quantization_step < 0.00000004f)
                quantization_step = 0.00000004f;

            codestream.access_qcd().set_irrev_quant(quantization_step);
        }
        codestream.set_planar(false);

        base = &mem_input;

        ojph::mem_outfile memOutput;
        memOutput.open();
        char* com_string = NULL;
        ojph::comment_exchange com_ex;
        if (com_string)
            com_ex.set_string(com_string);

        codestream.write_headers(&memOutput, &com_ex, com_string ? 1 : 0);

        ojph::ui32 next_comp;
        ojph::line_buf* cur_line = codestream.exchange(NULL, next_comp);

        ojph::ui32 height = siz.get_image_extent().y;
        height -= siz.get_image_offset().y;

        void* buffer = NULL;
        ojph::ui32 bytes_per_sample = (bitDepth + 7) >> 3;
        size_t buffer_size = (size_t)width * bytes_per_sample;
        buffer = (ojph::ui8*)malloc(buffer_size);

        for (ojph::ui32 i = 0; i < height; ++i)
        {
            ojph::si32* dp = cur_line->i32;
            std::memcpy(buffer, uncompressedData, buffer_size);

            if (bytes_per_sample == 1)
            {
                if (isSigned) {
                    const ojph::si8* sp = (ojph::si8*)buffer;
                    for (ojph::ui32 i = width; i > 0; --i, ++sp)
                        *dp++ = *sp;
                }
                else {
                    const ojph::ui8* sp = (ojph::ui8*)buffer;
                    for (ojph::ui32 i = width; i > 0; --i, ++sp)
                        *dp++ = (ojph::si32)*sp;
                }
            }

            if (bytes_per_sample == 2) {
                if (isSigned) {
                    const ojph::si16* sp = (ojph::si16*)buffer;
                    for (ojph::ui32 i = width; i > 0; --i, ++sp)
                        *dp++ = *sp;
                }
                else {
                    const ojph::ui16* sp = (ojph::ui16*)buffer;
                    for (ojph::ui32 i = width; i > 0; --i, ++sp)
                        *dp++ = (ojph::si32)*sp;
                }
            }

            if (bytes_per_sample == 4) {
                if (isSigned) {
                    const ojph::si32* sp = (ojph::si32*)buffer;
                    for (ojph::ui32 i = width; i > 0; --i, ++sp)
                        *dp++ = *sp;
                }
                else {
                    const ojph::ui32* sp = (ojph::ui32*)buffer;
                    for (ojph::ui32 i = width; i > 0; --i, ++sp)
                        *dp++ = (ojph::si32)*sp;
                }
            }

            uncompressedData += buffer_size;
            cur_line = codestream.exchange(cur_line, next_comp);
        }

        free(buffer);
        codestream.flush();
        codestream.close();
        base->close();

        const uint8_t* compressedData = memOutput.get_data();

        size_t dataSizeCompressed = memOutput.get_used_size();
        *compressedSize = dataSizeCompressed;
        uint8_t* returnBuffer = (uint8_t*)malloc(dataSizeCompressed);
        if (returnBuffer == nullptr) {
            printf("Memory allocation for compressed data failed.\n");
            return nullptr;
        }
        memcpy(returnBuffer, compressedData, dataSizeCompressed);
        return returnBuffer;
    }
    else if (is_matching("rgb", photometricInterpretation, 3) || is_matching("YBR", photometricInterpretation, 3)) {
        ojph::param_siz siz = codestream.access_siz();

        siz.set_image_extent(ojph::point(image_offset.x + dims.w, image_offset.y + dims.h));

        ojph::mem_infile mem_input;
        mem_input.open(uncompressedData, dataSize);

        num_components = 3;
        num_comp_downsamps = 1;
        num_is_signed = isSigned;
        num_bit_depths = bitDepth;

        ojph::ui32 last_signed_idx = 0, last_bit_depth_idx = 0;
        ojph::ui32 last_downsamp_idx = 0;

        siz.set_num_components(num_components);
        for (ojph::ui32 c = 0; c < num_components; ++c)
            siz.set_component(c, ojph::point(1, 1), num_bit_depths, num_is_signed);

        ojph::param_cod cod = codestream.access_cod();
        cod.set_num_decomposition(3);
        cod.set_block_dims(block_size.w, block_size.h);

        if (num_precincts != -1)
            cod.set_precinct_size(num_precincts, precinct_size);

        cod.set_progression_order(prog_order);                    // progression order; should be RPCL

        if (is_matching("RGB", photometricInterpretation, 3))
            cod.set_color_transform(true);
        else
            cod.set_color_transform(false);

        if (!reversible)
        {
            float exponent = 1.5f;
            float compression_scale = 0.1f;

            if (!planarConfig) {
                if (bitDepth == 8)
                {
                    exponent = 1.42f;              // Lower the curve, smooth compression rise
                    compression_scale = 0.14f;     // Slightly stronger compression baseline
                }
                else if (bitDepth == 16)
                {
                    exponent = 1.6f;
                    compression_scale = 0.08f;
                }
                else if (bitDepth = 32)
                {
                    // not calculated yet
                }
            }
            else
            {
                // Not been tested thoroughly 
                if (bitDepth == 8)
                {
                    exponent = 1.8f;              // Lower the curve, smooth compression rise
                    compression_scale = 0.02f;     // Slightly stronger compression baseline
                }
                else if (bitDepth == 16)
                {
                    exponent = 1.8f;
                    compression_scale = 0.02f;
                }
                else if (bitDepth = 32)
                {
                    // not calculated yet
                }
            }

            float base_step = 1.0f / powf(2.0f, (float)(bitDepth - 1));
            float quantization_step = base_step * powf((float)quality, exponent) * compression_scale;

            // Clamp for safety
            if (quantization_step < 0.00000004f)
                quantization_step = 0.00000004f;

            codestream.access_qcd().set_irrev_quant(quantization_step);
        }

        codestream.set_planar(false);

        if (profile_string[0] != '\0')
            codestream.set_profile(profile_string);
        codestream.set_tilepart_divisions(tileparts_at_resolutions, tileparts_at_components); // need to check this against the standard
        codestream.request_tlm_marker(tlm_marker);                // // need to check this against the standard     

        base = &mem_input;

        ojph::mem_outfile memOutput;
        memOutput.open();
        char* com_string = NULL;
        ojph::comment_exchange com_ex;
        if (com_string)
            com_ex.set_string(com_string);

        codestream.write_headers(&memOutput, &com_ex, com_string ? 1 : 0);

        ojph::ui32 next_comp;
        ojph::line_buf* cur_line = codestream.exchange(NULL, next_comp);

        void* buffer = NULL;
        ojph::ui32 bytes_per_sample = (bitDepth + 7) >> 3;
        size_t buffer_size = (size_t)width * bytes_per_sample * num_components;
        buffer = (ojph::ui8*)malloc(buffer_size);


        if (!planarConfig) {
            for (ojph::ui32 i = 0; i < height; i++)
            {
                std::memcpy(buffer, uncompressedData, buffer_size);

                for (ojph::ui32 c = 0; c < num_components; ++c)
                {
                    ojph::si32* dp = cur_line->i32;  // Destination buffer for current channel

                    if (bytes_per_sample == 1)
                    {
                        const ojph::ui8* sp = (ojph::ui8*)buffer;  // Start at the correct channel (R, G, or B)
                        for (ojph::ui32 x = 0; x < width; ++x)
                            dp[x] = (ojph::si32)sp[3 * x + c];  // Extract the correct component (R, G, or B)                   
                    }

                    else if (bytes_per_sample == 2)
                    {
                        const ojph::ui16* sp16 = (ojph::ui16*)buffer;
                        for (ojph::ui32 x = 0; x < width; ++x)
                            dp[x] = (ojph::si32)sp16[3 * x + c];
                    }

                    cur_line = codestream.exchange(cur_line, next_comp);
                }
                uncompressedData += buffer_size;

            }
        }
        else
        {
            for (ojph::ui32 i = 0; i < height; i++)
            {
                for (ojph::ui32 c = 0; c < num_components; ++c)
                {
                    ojph::si32* dp = cur_line->i32;  // Get buffer for the current color plane

                    // Compute the correct position for the component
                    const ojph::ui8* sp = uncompressedData + (c * width * height * bytes_per_sample)
                        + (i * width * bytes_per_sample);

                    if (bytes_per_sample == 1) {  // 8-bit input case
                        for (ojph::ui32 x = 0; x < width; ++x)
                        {
                            dp[x] = (ojph::si32)sp[x];  // Copy each pixel for the current component
                        }
                    }
                    else if (bytes_per_sample == 2) {  // 16-bit case
                        const ojph::ui16* sp16 = (ojph::ui16*)sp;
                        for (ojph::ui32 x = 0; x < width; ++x)
                        {
                            dp[x] = (ojph::si32)sp16[x];  // Copy each 16-bit value
                        }
                    }

                    cur_line = codestream.exchange(cur_line, next_comp);  // Move to the next component buffer
                }
            }
        }

        free(buffer);
        codestream.flush();
        codestream.close();
        base->close();

        const uint8_t* compressedData = memOutput.get_data();
        size_t dataSizeCompressed = memOutput.get_used_size();
        *compressedSize = dataSizeCompressed;
        uint8_t* returnBuffer = (uint8_t*)malloc(dataSizeCompressed);
        if (returnBuffer == nullptr) {
            printf("Memory allocation for compressed data failed.\n");
            return nullptr;
        }

        memcpy(returnBuffer, compressedData, dataSizeCompressed);
        return returnBuffer;
    }
}
