package dev.codeflush.baseencoder;

import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.nio.CharBuffer;

/**
 * An {@link InputStream} that decodes characters read from a {@link Reader} back into bytes
 * using a {@link BaseEncoding} alphabet.
 *
 * <p>Each character is mapped to its alphabet index via {@code encoding.index(...)} and
 * contributes {@code encoding.bitBlockSize()} bits, most significant bit first. Every completed
 * group of eight bits is emitted as one byte. Bits that do not fill a whole byte when the reader
 * is exhausted are dropped if they are all zero, otherwise they are emitted as a final byte.
 *
 * <p>This class keeps mutable per-stream decoding state and performs no synchronization, so an
 * instance must not be used from several threads without external locking.
 */
public class BaseDecoder extends InputStream {

    /** Maximum number of encoded characters pulled from the reader per refill. */
    private static final int BUFFER_SIZE = 8192;

    private final BaseEncoding encoding;
    private final Reader in;
    /** Scratch buffer for encoded characters; reused across reads. */
    private final CharBuffer buffer;
    /** Upper bound on the number of bytes a single call may decode from one buffer refill. */
    private final int maxBytesPerRead;
    /** Reused destination for {@link #read()} so single-byte reads do not allocate. */
    private final byte[] singleByte = new byte[1];
    /** Number of bits (0-7) already accumulated into {@link #unfinishedByte}. */
    private int bitIndex;
    /** Partially assembled output byte, filled from its most significant bit downwards. */
    private byte unfinishedByte;

    /**
     * Creates a decoder that reads encoded characters from {@code in}.
     *
     * @param encoding alphabet mapping characters to bit blocks; NOTE(review): the decoding
     *                 logic assumes {@code bitBlockSize()} is between 1 and 8 — confirm for
     *                 every {@link BaseEncoding} in use
     * @param in       source of encoded characters; closed by {@link #close()}
     */
    public BaseDecoder(BaseEncoding encoding, Reader in) {
        this.encoding = encoding;
        this.in = in;
        this.buffer = CharBuffer.allocate(BUFFER_SIZE);
        this.maxBytesPerRead = roundUp(BUFFER_SIZE * encoding.bitBlockSize(), Byte.SIZE) / Byte.SIZE;
        this.bitIndex = 0;
        this.unfinishedByte = 0;
    }

    /**
     * Reads a single decoded byte.
     *
     * @return the next byte as an unsigned value in {@code 0..255}, or {@code -1} at end of stream
     * @throws IOException if the underlying reader fails
     */
    @Override
    public int read() throws IOException {
        // read(byte[], int, int) never returns 0 for a non-empty request, so any result other
        // than -1 means singleByte[0] was written.
        if (read(this.singleByte, 0, 1) == -1) {
            return -1;
        }

        return this.singleByte[0] & 0xFF;
    }

    /**
     * Reads decoded bytes into the whole of {@code buffer}.
     *
     * @see #read(byte[], int, int)
     */
    @Override
    public int read(byte[] buffer) throws IOException {
        return read(buffer, 0, buffer.length);
    }

    /**
     * Decodes up to {@code length} bytes into {@code buffer} starting at {@code offset}.
     *
     * <p>Blocks until at least one byte has been decoded or the reader is exhausted. A request
     * larger than the internal limit is silently reduced to it.
     *
     * @param buffer destination array
     * @param offset index of the first byte to write
     * @param length maximum number of bytes to write
     * @return the number of bytes written, {@code 0} only when {@code length == 0}, or {@code -1}
     *         at end of stream
     * @throws IndexOutOfBoundsException if {@code offset}/{@code length} do not describe a range
     *                                   inside {@code buffer}
     * @throws IOException               if the underlying reader fails
     */
    @Override
    public int read(byte[] buffer, int offset, int length) throws IOException {
        if (offset < 0 || length < 0 || length > buffer.length - offset) {
            throw new IndexOutOfBoundsException(
                    "offset=" + offset + ", length=" + length + ", array length=" + buffer.length);
        }
        if (length == 0) {
            return 0;
        }
        if (length > this.maxBytesPerRead) {
            length = this.maxBytesPerRead;
        }

        // A short read from the reader may not complete a single byte (e.g. one character of a
        // 6-bit alphabet). Keep pulling characters until a byte is produced, because returning 0
        // for a non-empty request would violate the InputStream contract.
        int bytesRead = 0;
        while (bytesRead == 0) {
            final int charactersRead = fillCharBuffer(length);
            if (charactersRead == -1) {
                return flushAtEndOfStream(buffer, offset);
            }
            bytesRead = decodeBufferedChars(charactersRead, buffer, offset);
        }

        return bytesRead;
    }

    @Override
    public void close() throws IOException {
        this.in.close();
    }

    /**
     * Refills the character buffer with the number of characters needed to complete
     * {@code length} bytes, accounting for bits already pending in {@link #unfinishedByte}.
     *
     * @param length number of bytes the caller wants (at least 1)
     * @return the number of characters read into the buffer, or {@code -1} at end of input
     */
    private int fillCharBuffer(int length) throws IOException {
        final int bitsPerChar = this.encoding.bitBlockSize();
        final int bitsLeft = length * Byte.SIZE - this.bitIndex;
        int charsToRead = roundUp(bitsLeft, bitsPerChar) / bitsPerChar;

        if (charsToRead > this.buffer.capacity()) {
            charsToRead = this.buffer.capacity();
        }

        this.buffer.clear();
        this.buffer.limit(charsToRead);

        return this.in.read(this.buffer);
    }

    /**
     * Feeds the bits of the first {@code count} buffered characters into the pending byte and
     * writes every completed byte to {@code target}, starting at {@code offset}.
     *
     * @return the number of bytes written to {@code target}
     */
    private int decodeBufferedChars(int count, byte[] target, int offset) {
        final int bitsPerChar = this.encoding.bitBlockSize();
        int bytesWritten = 0;

        for (int i = 0; i < count; i++) {
            // NOTE(review): handling of characters outside the alphabet depends on
            // BaseEncoding.index(...), which is not visible here — confirm it rejects them.
            final int charIndex = this.encoding.index(this.buffer.get(i));

            // Copy the character's bits, most significant first, into the pending byte.
            for (int j = bitsPerChar - 1; j >= 0; j--) {
                if (getBit(charIndex, j)) {
                    this.unfinishedByte |= 1 << (Byte.SIZE - this.bitIndex - 1);
                }

                this.bitIndex++;

                if (this.bitIndex >= Byte.SIZE) {
                    target[offset + bytesWritten] = this.unfinishedByte;
                    bytesWritten++;
                    this.unfinishedByte = 0;
                    this.bitIndex = 0;
                }
            }
        }

        return bytesWritten;
    }

    /**
     * Handles exhaustion of the reader: emits the pending partial byte if any of its bits are set,
     * then resets the bit accumulator.
     *
     * @return {@code 1} if a partial byte was written to {@code target[offset]}, otherwise
     *         {@code -1}
     */
    private int flushAtEndOfStream(byte[] target, int offset) {
        // Leftover bits that are all zero are presumably encoder padding and are dropped;
        // a non-zero remainder is emitted as one final byte.
        final boolean hasPendingData = this.bitIndex > 0 && this.unfinishedByte != 0;
        if (hasPendingData) {
            target[offset] = this.unfinishedByte;
        }

        this.unfinishedByte = 0;
        this.bitIndex = 0;

        return hasPendingData ? 1 : -1;
    }

    /** Rounds {@code num} up to the next multiple of {@code roundTo} (for non-negative input). */
    private static int roundUp(int num, int roundTo) {
        final int mod = num % roundTo;
        if (mod != 0) {
            num += roundTo - mod;
        }

        return num;
    }

    /** Returns whether bit {@code n} (0 = least significant) of {@code v} is set. */
    private static boolean getBit(int v, int n) {
        return (v & (1 << n)) != 0;
    }
}
