#include "Decoder.h"

#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>

/////////////////////////////////////////////////////////////////////////////

// Case-insensitive prefix comparison: returns true when the first `len`
// characters of `other` equal those of `ref`, ignoring letter case on both
// sides (per std::tolower), e.g. ("mono", "MONOCHROME2", 4) matches.
// If both strings end before `len` characters the comparison stops there
// with a match, so neither string is read past its terminator.
// Returns false if either pointer is null; a non-positive `len` always matches.
static
bool is_matching(const char* ref, const char* other, int len)
{
    if (ref == nullptr || other == nullptr)
        return false;

    for (int i = 0; i < len; ++i)
    {
        // std::tolower requires a value representable as unsigned char;
        // passing a negative plain char is undefined behaviour.
        const int r = std::tolower(static_cast<unsigned char>(ref[i]));
        const int o = std::tolower(static_cast<unsigned char>(other[i]));
        if (r != o)
            return false;
        if (o == '\0') // both strings ended at the same position
            break;
    }
    return true;
}

/////////////////////////////////////////////////////////////////////////////

extern "C" uint8_t* OpenJPHWrapper::Decompress(const uint8_t* compressedData, size_t dataSize,
    char* photometricInterpretation, ojph::ui32 planarConfig, size_t* decompressedSize)
{
    ojph::ui32 skipped_res_for_read = 0;
    ojph::ui32 skipped_res_for_recon = 0;
    bool resilient = false;

    ojph::mem_infile mem_input;
    mem_input.open(compressedData, dataSize);
    ojph::codestream codestream;
    ojph::image_out_base* base = NULL;

    if (resilient)
        codestream.enable_resilience();
    codestream.read_headers(&mem_input);
    codestream.restrict_input_resolution(skipped_res_for_read, skipped_res_for_recon);

    ojph::param_siz siz = codestream.access_siz();
    ojph::ui32 width = siz.get_image_extent().x - siz.get_image_offset().x;
    ojph::ui32 height = siz.get_image_extent().y - siz.get_image_offset().y;
    ojph::ui32 numComponents = siz.get_num_components();
    ojph::ui32 bitDepth = siz.get_bit_depth(0);
    bool is_signed = siz.is_signed(0);
    ojph::ui32 bytesPerSample = (bitDepth + 7) / 8;

    size_t totalSize = (size_t)width * height * numComponents * bytesPerSample;
    *decompressedSize = totalSize;
    uint8_t* decompressedBuffer = (uint8_t*)malloc(totalSize);
    if (decompressedBuffer == nullptr) {
        printf("Memory allocation for decompressed data failed.\n");
        return nullptr;
    }

    ojph::ui32 next_comp;
    ojph::line_buf* cur_line = codestream.exchange(NULL, next_comp);
    uint8_t* outPtr = decompressedBuffer;

    codestream.create();

    if (is_matching("mono", photometricInterpretation, 4))
    {
        // used for debugging
        /*FILE* debugFile;
        if (fopen_s(&debugFile, "D:\\OpenJPH resources\\DirectoryOfAll\\DecompressedOutput\\CT_debug_decompressed.raw", "wb") != 0) {
            printf("Failed to open debug_decompressed.raw for writing\n");
            return nullptr;
        }*/

        for (ojph::ui32 i = 0; i < height; ++i)
        {
            ojph::ui32 comp_num;
            ojph::line_buf* line = codestream.pull(comp_num);
            ojph::si32* dp = line->i32;

            if (bytesPerSample == 1)
            {
                ojph::ui8* dest = (ojph::ui8*)outPtr;
                for (ojph::ui32 x = 0; x < width; ++x)
                    dest[x] = (ojph::ui8)dp[x];

                outPtr += width;
                //fwrite(dp, bytesPerSample, width, debugFile); 
            }
            else if (bytesPerSample == 2)
            {
                ojph::ui16* dest = (ojph::ui16*)outPtr;
                for (ojph::ui32 x = 0; x < width; ++x)
                    dest[x] = (ojph::ui16)dp[x];

                outPtr += width * 2;

                //fwrite(dest, bytesPerSample, width, debugFile); 
            }
            else if (bytesPerSample == 4)
            {
                ojph::ui32* dest = (ojph::ui32*)outPtr;
                for (ojph::ui32 x = 0; x < width; ++x)
                    dest[x] = (ojph::ui32)dp[x];

                outPtr += width * 4;
            }
        }

        //fclose(debugFile); 
        //printf("Decompressed data written to debug_decompressed.raw\n");
    }
    else if (is_matching("rgb", photometricInterpretation, 3) || is_matching("YBR", photometricInterpretation, 3))
    {
        if (!planarConfig)
        {
            for (ojph::ui32 i = 0; i < height; ++i)
            {
                for (ojph::ui32 c = 0; c < numComponents; ++c)
                {
                    ojph::line_buf* line = codestream.pull(c);
                    ojph::si32* dp = line->i32;

                    if (bytesPerSample == 1)
                    {
                        ojph::ui8* dest = (ojph::ui8*)outPtr;
                        for (ojph::ui32 x = 0; x < width; ++x)
                            dest[x * numComponents + c] = (ojph::ui8)dp[x];
                    }
                    else if (bytesPerSample == 2)
                    {
                        ojph::ui16* dest = (ojph::ui16*)outPtr;
                        for (ojph::ui32 x = 0; x < width; ++x)
                            dest[x * numComponents + c] = (ojph::ui16)dp[x];
                    }
                }
                outPtr += width * numComponents * bytesPerSample;
            }
        }
        else
        {
            for (ojph::ui32 i = 0; i < height; ++i)
            {
                for (ojph::ui32 c = 0; c < numComponents; ++c)
                {
                    ojph::line_buf* line = codestream.pull(c);
                    ojph::si32* dp = line->i32;

                    size_t offset = (c * width * height) + (i * width);

                    if (bytesPerSample == 1)
                    {
                        ojph::ui8* dest = (ojph::ui8*)outPtr + offset;
                        for (ojph::ui32 x = 0; x < width; ++x)
                            dest[x] = (ojph::ui8)dp[x];
                    }
                    else if (bytesPerSample == 2)
                    {
                        ojph::ui16* dest16 = (ojph::ui16*)(outPtr + offset);
                        for (ojph::ui32 x = 0; x < width; ++x)
                            dest16[x] = (ojph::ui16)dp[x];
                    }
                }
            }
        }
    }

    codestream.close();
    return decompressedBuffer;
}
