/*
* Copyright (c) 2017, Planet Labs
* Author: <even.rouault at spatialys.com>
*
* Permission to use, copy, modify, distribute, and sell this software and
* its documentation for any purpose is hereby granted without fee, provided
* that (i) the above copyright notices and this permission notice appear in
* all copies of the software and related documentation, and (ii) the names of
* Sam Leffler and Silicon Graphics may not be used in any advertising or
* publicity relating to the software without the specific, prior written
* permission of Sam Leffler and Silicon Graphics.
*
* THE SOFTWARE IS PROVIDED "AS-IS" AND WITHOUT WARRANTY OF ANY KIND,
* EXPRESS, IMPLIED OR OTHERWISE, INCLUDING WITHOUT LIMITATION, ANY
* WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE.
*
* IN NO EVENT SHALL SAM LEFFLER OR SILICON GRAPHICS BE LIABLE FOR
* ANY SPECIAL, INCIDENTAL, INDIRECT OR CONSEQUENTIAL DAMAGES OF ANY KIND,
* OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER OR NOT ADVISED OF THE POSSIBILITY OF DAMAGE, AND ON ANY THEORY OF
* LIABILITY, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THIS SOFTWARE.
*/

#include "tiffiop.h"
#ifdef ZSTD_SUPPORT
/*
* TIFF Library.
*
* ZSTD Compression Support
*
*/

#include "tif_predict.h"
#include "zstd.h"

#include <stdio.h>

/*
* State block for each open TIFF file using ZSTD compression/decompression.
*/
typedef struct {
        /*
        * Predictor state shared with the TIFFPredictor* routines.
        * NOTE(review): presumably has to remain the first member so the
        * predictor code can view tif_data as a TIFFPredictorState --
        * confirm against tif_predict.c before reordering.
        */
        TIFFPredictorState predict;
        /* decompression stream; (re)created by ZSTDPreDecode */
        ZSTD_DStream*   dstream;
        /* compression stream; (re)created by ZSTDPreEncode */
        ZSTD_CStream*   cstream;
        int             compression_level;      /* level handed to ZSTD_initCStream */
        /* encoder output cursor; pointed at the raw data buffer by ZSTDPreEncode */
        ZSTD_outBuffer  out_buffer;
        int             state;                  /* LSTATE_* flags defined below */
        /* LSTATE_INIT_DECODE is set by ZSTDSetupDecode, LSTATE_INIT_ENCODE by ZSTDSetupEncode */
#define LSTATE_INIT_DECODE 0x01
#define LSTATE_INIT_ENCODE 0x02

        /* tag methods that were installed before TIFFInitZSTD hooked them */
        TIFFVGetMethod  vgetparent;            /* previous vgetfield, restored by ZSTDCleanup */
        TIFFVSetMethod  vsetparent;            /* previous vsetfield, restored by ZSTDCleanup */
} ZSTDState;

/* View the codec-private data stored in tif->tif_data as a ZSTDState. */
#define LState(tif)             ((ZSTDState*) (tif)->tif_data)
/* Decoder- and encoder-side aliases; both expand to LState(tif). */
#define DecoderState(tif)       LState(tif)
#define EncoderState(tif)       LState(tif)

/* Prototypes for the encode/decode routines defined later in this file. */
static int ZSTDEncode(TIFF* tif, uint8* bp, tmsize_t cc, uint16 s);
static int ZSTDDecode(TIFF* tif, uint8* op, tmsize_t occ, uint16 s);

/*
* Tag fix-up hook installed as tif_fixuptags by TIFFInitZSTD.
* Does no work: the argument is unused and the function always returns 1.
*/
static int
ZSTDFixupTags(TIFF* tif)
{
        (void) tif;
        return 1;
}

/*
* Put the codec state into decode mode.
*
* If the state block was last set up for encoding, the compression stream
* is released and the state flags are cleared first, so that only
* LSTATE_INIT_DECODE remains set afterwards.  Always returns 1.
*/
static int
ZSTDSetupDecode(TIFF* tif)
{
        ZSTDState* zs = DecoderState(tif);

        assert(zs != NULL);

        if ((zs->state & LSTATE_INIT_ENCODE) != 0) {
            /* leaving encode mode: drop the compressor, start from clean flags */
            ZSTD_freeCStream(zs->cstream);
            zs->cstream = NULL;
            zs->state = LSTATE_INIT_DECODE;
        } else {
            zs->state |= LSTATE_INIT_DECODE;
        }

        return 1;
}

/*
* Prepare to decode a strip or tile.
*
* Runs the setup-decode hook when decode mode has not been initialised yet,
* then replaces any existing decompression stream with a freshly created
* and initialised one.  Returns 1 on success, or 0 (after reporting the
* error) when the stream cannot be created or initialised.
*/
static int
ZSTDPreDecode(TIFF* tif, uint16 s)
{
        static const char module[] = "ZSTDPreDecode";
        ZSTDState* zs = DecoderState(tif);
        size_t rc;

        (void) s;
        assert(zs != NULL);

        if( !(zs->state & LSTATE_INIT_DECODE) )
            tif->tif_setupdecode(tif);

        /* each strip/tile is decoded with a brand-new stream */
        if( zs->dstream != NULL )
            ZSTD_freeDStream(zs->dstream);
        zs->dstream = ZSTD_createDStream();
        if( zs->dstream == NULL ) {
            TIFFErrorExt(tif->tif_clientdata, module,
                         "Cannot allocate decompression stream");
            return 0;
        }

        rc = ZSTD_initDStream(zs->dstream);
        if( !ZSTD_isError(rc) )
            return 1;

        TIFFErrorExt(tif->tif_clientdata, module,
                     "Error in ZSTD_initDStream(): %s",
                     ZSTD_getErrorName(rc));
        return 0;
}

/*
* Decode occ bytes of the current strip/tile into op.
*
* Input is taken from tif->tif_rawcp (tif->tif_rawcc bytes available) and
* run through the decompression stream prepared by ZSTDPreDecode.
*
* Returns 1 when exactly occ bytes were produced; the raw-data cursor
* (tif_rawcp / tif_rawcc) is then advanced past the consumed input.
* Returns 0, after reporting the error, when zstd signals a failure or
* when fewer than occ bytes could be produced.
*/
static int
ZSTDDecode(TIFF* tif, uint8* op, tmsize_t occ, uint16 s)
{
        static const char module[] = "ZSTDDecode";
        ZSTDState* sp = DecoderState(tif);
        ZSTD_inBuffer   in_buffer;
        ZSTD_outBuffer  out_buffer;
        size_t          zstd_ret;

        (void) s;
        assert(sp != NULL);
        assert(sp->state == LSTATE_INIT_DECODE);

        in_buffer.src = tif->tif_rawcp;
        in_buffer.size = (size_t) tif->tif_rawcc;
        in_buffer.pos = 0;

        out_buffer.dst = op;
        out_buffer.size = (size_t) occ;
        out_buffer.pos = 0;

        /*
        * Keep decompressing until zstd reports the frame complete
        * (return value 0), the input is exhausted, or the output is full.
        */
        do {
                zstd_ret = ZSTD_decompressStream(sp->dstream, &out_buffer,
                                                 &in_buffer);
                if( ZSTD_isError(zstd_ret) ) {
                    TIFFErrorExt(tif->tif_clientdata, module,
                                "Error in ZSTD_decompressStream(): %s",
                                ZSTD_getErrorName(zstd_ret));
                    return 0;
                }
        } while( zstd_ret != 0 &&
                 in_buffer.pos < in_buffer.size &&
                 out_buffer.pos < out_buffer.size );

        if (out_buffer.pos < (size_t)occ) {
                /*
                * The subtraction is parenthesised so the cast applies to
                * the byte shortfall as a whole; otherwise a size_t would
                * be passed for a %lu conversion.
                */
                TIFFErrorExt(tif->tif_clientdata, module,
                    "Not enough data at scanline %lu (short %lu bytes)",
                    (unsigned long) tif->tif_row,
                    (unsigned long) ((size_t)occ - out_buffer.pos));
                return 0;
        }

        /* advance the raw-data cursor past what the decoder consumed */
        tif->tif_rawcp += in_buffer.pos;
        tif->tif_rawcc -= in_buffer.pos;

        return 1;
}

/*
* Put the codec state into encode mode.
*
* If the state block was last set up for decoding, the decompression
* stream is released and the state flags are cleared first, so that only
* LSTATE_INIT_ENCODE remains set afterwards.  Always returns 1.
*/
static int
ZSTDSetupEncode(TIFF* tif)
{
        ZSTDState* zs = EncoderState(tif);

        assert(zs != NULL);

        if ((zs->state & LSTATE_INIT_DECODE) != 0) {
            /* leaving decode mode: drop the decompressor, start from clean flags */
            ZSTD_freeDStream(zs->dstream);
            zs->dstream = NULL;
            zs->state = LSTATE_INIT_ENCODE;
        } else {
            zs->state |= LSTATE_INIT_ENCODE;
        }

        return 1;
}

/*
* Prepare to encode a strip or tile.
*
* Runs the setup-encode hook unless the state is exactly
* LSTATE_INIT_ENCODE, replaces any existing compression stream with a new
* one initialised at the configured compression level, and points the
* output cursor at the start of the raw data buffer.  Returns 1 on
* success, or 0 (after reporting the error) when the stream cannot be
* created or initialised.
*/
static int
ZSTDPreEncode(TIFF* tif, uint16 s)
{
        static const char module[] = "ZSTDPreEncode";
        ZSTDState *zs = EncoderState(tif);
        size_t rc;

        (void) s;
        assert(zs != NULL);

        if( zs->state != LSTATE_INIT_ENCODE )
            tif->tif_setupencode(tif);

        /* each strip/tile is compressed with a brand-new stream */
        if( zs->cstream != NULL )
            ZSTD_freeCStream(zs->cstream);
        zs->cstream = ZSTD_createCStream();
        if( zs->cstream == NULL ) {
            TIFFErrorExt(tif->tif_clientdata, module,
                         "Cannot allocate compression stream");
            return 0;
        }

        rc = ZSTD_initCStream(zs->cstream, zs->compression_level);
        if( ZSTD_isError(rc) ) {
            TIFFErrorExt(tif->tif_clientdata, module,
                         "Error in ZSTD_initCStream(): %s",
                         ZSTD_getErrorName(rc));
            return 0;
        }

        /* compressed output goes straight into the raw data buffer */
        zs->out_buffer.pos = 0;
        zs->out_buffer.size = (size_t)tif->tif_rawdatasize;
        zs->out_buffer.dst = tif->tif_rawdata;

        return 1;
}

/*
* Encode a chunk of pixels.
*
* Feeds cc bytes from bp to the compression stream prepared by
* ZSTDPreEncode.  Compressed output accumulates in the raw data buffer;
* whenever that buffer fills up it is written out with TIFFFlushData1()
* and the output cursor is rewound.
*
* Returns 1 once all input has been consumed.  Returns 0 if zstd reports
* an error (which is logged here) or if flushing the raw data buffer
* fails.
*/
static int
ZSTDEncode(TIFF* tif, uint8* bp, tmsize_t cc, uint16 s)
{
        static const char module[] = "ZSTDEncode";
        ZSTDState *sp = EncoderState(tif);
        ZSTD_inBuffer   in_buffer;
        size_t          zstd_ret;

        assert(sp != NULL);
        assert(sp->state == LSTATE_INIT_ENCODE);

        (void) s;

        in_buffer.src = bp;
        in_buffer.size = (size_t)cc;
        in_buffer.pos = 0;

        do {
                zstd_ret = ZSTD_compressStream(sp->cstream, &sp->out_buffer,
                                               &in_buffer);
                if( ZSTD_isError(zstd_ret) ) {
                    TIFFErrorExt(tif->tif_clientdata, module,
                                "Error in ZSTD_compressStream(): %s",
                                ZSTD_getErrorName(zstd_ret));
                    return 0;
                }
                if( sp->out_buffer.pos == sp->out_buffer.size ) {
                        /* raw buffer is full: write it out and start over */
                        tif->tif_rawcc = tif->tif_rawdatasize;
                        /*
                        * Stop on a failed flush rather than carrying on
                        * as if the data had been written.
                        */
                        if( !TIFFFlushData1(tif) )
                                return 0;
                        sp->out_buffer.dst = tif->tif_rawcp;
                        sp->out_buffer.pos = 0;
                }
        } while( in_buffer.pos < in_buffer.size );

        return 1;
}

/*
* Finish off an encoded strip by flushing it.
*
* Calls ZSTD_endStream() repeatedly until it returns 0, i.e. until the
* frame epilogue has been fully emitted.  After each call any bytes
* sitting in the raw data buffer are written out with TIFFFlushData1()
* and the output cursor is rewound.
*
* Returns 1 on success.  Returns 0 if zstd reports an error (which is
* logged here) or if flushing the raw data buffer fails.
*/
static int
ZSTDPostEncode(TIFF* tif)
{
        static const char module[] = "ZSTDPostEncode";
        ZSTDState *sp = EncoderState(tif);
        size_t          zstd_ret;

        do {
                zstd_ret = ZSTD_endStream(sp->cstream, &sp->out_buffer);
                if( ZSTD_isError(zstd_ret) ) {
                    TIFFErrorExt(tif->tif_clientdata, module,
                                "Error in ZSTD_endStream(): %s",
                                ZSTD_getErrorName(zstd_ret));
                    return 0;
                }
                if( sp->out_buffer.pos > 0 ) {
                        tif->tif_rawcc = sp->out_buffer.pos;
                        /*
                        * Stop on a failed flush rather than reporting
                        * the strip as successfully written.
                        */
                        if( !TIFFFlushData1(tif) )
                                return 0;
                        sp->out_buffer.dst = tif->tif_rawcp;
                        sp->out_buffer.pos = 0;
                }
        } while (zstd_ret != 0);
        return 1;
}

/*
* Tear down the codec state attached to a TIFF handle.
*
* Cleans up the predictor, restores the tag get/set methods saved by
* TIFFInitZSTD, releases both zstd streams if present, frees the state
* block and resets the handle to the default compression state.
*/
static void
ZSTDCleanup(TIFF* tif)
{
        ZSTDState* zs = LState(tif);

        assert(zs != NULL);

        /* the predictor unhooks itself first, then this codec does */
        (void)TIFFPredictorCleanup(tif);

        tif->tif_tagmethods.vsetfield = zs->vsetparent;
        tif->tif_tagmethods.vgetfield = zs->vgetparent;

        if (zs->cstream != NULL) {
            ZSTD_freeCStream(zs->cstream);
            zs->cstream = NULL;
        }
        if (zs->dstream != NULL) {
            ZSTD_freeDStream(zs->dstream);
            zs->dstream = NULL;
        }

        _TIFFfree(zs);
        tif->tif_data = NULL;

        _TIFFSetDefaultCompressionState(tif);
}

/*
* Tag setter hook.
*
* TIFFTAG_ZSTD_LEVEL stores the int argument as the compression level and
* returns 1; a value outside 1..ZSTD_maxCLevel() only produces a warning
* and is stored regardless.  Every other tag is forwarded to the setter
* that was installed before this codec.
*/
static int
ZSTDVSetField(TIFF* tif, uint32 tag, va_list ap)
{
        static const char module[] = "ZSTDVSetField";
        ZSTDState* zs = LState(tif);
        int level;

        if (tag != TIFFTAG_ZSTD_LEVEL)
                return (*zs->vsetparent)(tif, tag, ap);

        level = (int) va_arg(ap, int);
        zs->compression_level = level;
        if( level <= 0 || level > ZSTD_maxCLevel() ) {
            TIFFWarningExt(tif->tif_clientdata, module,
                           "ZSTD_LEVEL should be between 1 and %d",
                           ZSTD_maxCLevel());
        }
        return 1;
}

/*
* Tag getter hook.
*
* TIFFTAG_ZSTD_LEVEL writes the current compression level through the
* int pointer taken from the argument list and returns 1.  Every other
* tag is forwarded to the getter that was installed before this codec.
*/
static int
ZSTDVGetField(TIFF* tif, uint32 tag, va_list ap)
{
        ZSTDState* zs = LState(tif);
        int* level_out;

        if (tag != TIFFTAG_ZSTD_LEVEL)
                return (*zs->vgetparent)(tif, tag, ap);

        level_out = va_arg(ap, int*);
        *level_out = zs->compression_level;
        return 1;
}

/*
* Codec-specific field table, merged into the TIFF handle by TIFFInitZSTD.
* It has a single entry: TIFFTAG_ZSTD_LEVEL, declared as a FIELD_PSEUDO
* field accessed as an int (TIFF_SETGET_INT) and handled by
* ZSTDVSetField / ZSTDVGetField.
*/
static const TIFFField ZSTDFields[] = {
        { TIFFTAG_ZSTD_LEVEL, 0, 0, TIFF_ANY, 0, TIFF_SETGET_INT,
          TIFF_SETGET_UNDEFINED,
          FIELD_PSEUDO, TRUE, FALSE, "ZSTD compression_level", NULL },
};

/*
* Install the ZSTD codec on a TIFF handle.
*
* Merges the codec-specific fields, allocates the ZSTDState block, hooks
* the tag get/set methods (remembering the previous ones), fills in the
* default state, installs the codec entry points and finally initialises
* the predictor.  scheme is expected to be COMPRESSION_ZSTD (asserted).
*
* Returns 1 on success, or 0 (after reporting the error) when the fields
* cannot be merged or the state block cannot be allocated.
*/
int
TIFFInitZSTD(TIFF* tif, int scheme)
{
        static const char module[] = "TIFFInitZSTD";
        ZSTDState* zs;

        assert( scheme == COMPRESSION_ZSTD );

        /* register the codec-specific tags */
        if (_TIFFMergeFields(tif, ZSTDFields, TIFFArrayCount(ZSTDFields)) == 0) {
                TIFFErrorExt(tif->tif_clientdata, module,
                            "Merging ZSTD codec-specific tags failed");
                return 0;
        }

        /* the tag methods need somewhere to record their values */
        tif->tif_data = (uint8*) _TIFFmalloc(sizeof(ZSTDState));
        if (tif->tif_data == NULL) {
                TIFFErrorExt(tif->tif_clientdata, module,
                            "No space for ZSTD state block");
                return 0;
        }
        zs = LState(tif);

        /* defaults for the codec-specific state */
        zs->compression_level = 9;
        zs->state = 0;
        zs->dstream = NULL;
        zs->cstream = NULL;
        zs->out_buffer.dst = NULL;
        zs->out_buffer.size = 0;
        zs->out_buffer.pos = 0;

        /* interpose on the tag accessors, remembering the previous ones */
        zs->vgetparent = tif->tif_tagmethods.vgetfield;
        zs->vsetparent = tif->tif_tagmethods.vsetfield;
        tif->tif_tagmethods.vgetfield = ZSTDVGetField;
        tif->tif_tagmethods.vsetfield = ZSTDVSetField;

        /* codec entry points: decode side */
        tif->tif_fixuptags = ZSTDFixupTags;
        tif->tif_setupdecode = ZSTDSetupDecode;
        tif->tif_predecode = ZSTDPreDecode;
        tif->tif_decoderow = ZSTDDecode;
        tif->tif_decodestrip = ZSTDDecode;
        tif->tif_decodetile = ZSTDDecode;

        /* codec entry points: encode side */
        tif->tif_setupencode = ZSTDSetupEncode;
        tif->tif_preencode = ZSTDPreEncode;
        tif->tif_postencode = ZSTDPostEncode;
        tif->tif_encoderow = ZSTDEncode;
        tif->tif_encodestrip = ZSTDEncode;
        tif->tif_encodetile = ZSTDEncode;

        tif->tif_cleanup = ZSTDCleanup;

        /* predictor support is initialised last, after the hooks above */
        (void) TIFFPredictorInit(tif);
        return 1;
}
#endif /* ZSTD_SUPPORT */

/* vim: set ts=8 sts=8 sw=8 noet: */