# -*- coding: utf-8 -*-

import wave
import sys
import struct
import array
import itertools

# Sanity check: a native signed short ('h') must be 2 bytes and little-endian,
# because WAV samples and Rice-code words are (de)serialised through
# array.array in native byte order.  Compare against a bytes literal so the
# check also holds on Python 3, and raise explicitly so it survives `-O`.
if struct.pack('h', 1) != b'\x01\x00':
    raise RuntimeError('this module requires 2-byte little-endian shorts')
  
# Check that a native unsigned long ('L') is 4 bytes.  The encoded-block
# header is 4 bytes long; this guards code that (de)serialises it with the
# native-size struct format 'L'.
# NOTE(review): native 'L' is 8 bytes on most 64-bit Linux/macOS builds, so
# this check fails at import there.  Once no header code uses native 'L' any
# more (e.g. it uses the fixed-size '<I'), this check can be dropped —
# confirm before removing.
assert struct.calcsize('L') == 4

# Number of samples per channel in each independently encoded block.
BLOCK_LEN = 1000


def wave_read(filename):
    """Read a 16-bit stereo 44.1 kHz uncompressed WAV file.

    Returns the two channels as lists of signed sample values,
    ``[left_channel, right_channel]``.

    string -> [[int]]

    Raises ValueError if the file is not 2-channel, 16-bit, 44100 Hz
    uncompressed PCM, or if it holds fewer sample bytes than its header
    announces.
    """

    sound = wave.open(filename, 'rb')
    try:
        (nch, samplew, afreq, nframes, comp, compname) = sound.getparams()
        # Explicit check instead of `assert`, which is stripped under -O.
        if not (nch == 2 and samplew == 2 and afreq == 44100 and
                comp == 'NONE' and compname == 'not compressed'):
            raise ValueError('%s: expected 16-bit stereo 44100 Hz PCM WAV'
                             % (filename,))
        data = sound.readframes(nframes)
    finally:
        # Close the file even when validation or reading fails.
        sound.close()

    # One frame = 2 channels * 2 bytes.
    if len(data) != 4 * nframes:
        raise ValueError('%s: truncated sample data' % (filename,))

    # 'h' is a native signed 16-bit integer; the module requires a
    # little-endian host, matching the WAV sample layout.
    samples = array.array('h', data)

    # Frames are interleaved: L, R, L, R, ...
    return [list(samples[0::2]), list(samples[1::2])]
    

def wave_write(filename, left_channel, right_channel):
    """Write two channels as a 16-bit stereo 44.1 kHz uncompressed WAV file.

    (string, [int], [int]) -> void

    Raises ValueError if the channels differ in length.  Sample values must
    fit in a signed 16-bit integer; array.array raises OverflowError
    otherwise.
    """

    if len(left_channel) != len(right_channel):
        raise ValueError('channels differ in length: %d vs %d'
                         % (len(left_channel), len(right_channel)))

    # WAV "frames" count one sample per channel, i.e. the channel length.
    nframes = len(left_channel)

    # Interleave as L, R, L, R, ... (WAV frame order).
    channels = [0] * (2 * nframes)
    channels[0::2] = left_channel
    channels[1::2] = right_channel
    samples = array.array('h', channels)

    # array.tostring() was removed in Python 3.9; tobytes() is its Python 3
    # name, while Python 2 only provides tostring().
    if hasattr(samples, 'tobytes'):
        data = samples.tobytes()
    else:
        data = samples.tostring()

    sound = wave.open(filename, 'wb')
    try:
        sound.setparams((2, 2, 44100, nframes, 'NONE', 'not compressed'))
        sound.writeframes(data)
    finally:
        sound.close()


def channel_to_blocks(channel, block_len=None):
    """Split a channel into consecutive blocks of at most `block_len` samples.

    `block_len` defaults to the module constant BLOCK_LEN.  The last block
    may be shorter.  Every block is a slice of `channel`, so it has the same
    sequence type as `channel`.

    ([int], int) -> [[int]]

    Raises ValueError if `block_len` is not positive.
    """

    if block_len is None:
        block_len = BLOCK_LEN
    if block_len <= 0:
        raise ValueError('block_len must be positive, got %r' % (block_len,))

    return [channel[start:start + block_len]
            for start in range(0, len(channel), block_len)]


def blocks_to_channel(blocks):
    """Concatenate a sequence of blocks back into one flat channel list.

    Inverse of channel_to_blocks.

    [[int]] -> [int]
    """

    return list(itertools.chain.from_iterable(blocks))


def pred_encode(block):
    """Encode a block with fixed polynomial prediction of order 0-3.

    For each order, every sample from index `order` onward is replaced by
    its prediction residual, and the first `order` samples are kept
    verbatim.  The order whose residuals have the smallest sum of absolute
    values wins.  Ties are broken by comparing the residual lists, then the
    order, i.e. the smallest ``(cost, residuals, order)`` triple is chosen.

    [int] -> ([int], int)
    Returns ``(residuals, order)``.
    """

    # Weights applied to block[i-1], block[i-2], block[i-3] by each order.
    weight_table = ((), (1,), (2, -1), (3, -3, 1))

    def residuals_for(weights):
        out = list(block)
        for pos in range(len(weights), len(block)):
            guess = 0
            for lag, weight in enumerate(weights, 1):
                guess += weight * block[pos - lag]
            out[pos] -= guess
        return out

    candidates = []
    for order, weights in enumerate(weight_table):
        encoded = residuals_for(weights)
        cost = sum(abs(val) for val in encoded)
        candidates.append((cost, encoded, order))

    _, best, order = min(candidates)
    return best, order

def pred_decode(block):
    """Reconstruct samples from a predictively encoded block.

    Inverse of pred_encode: `block` is the pair ``(residuals, order)``.
    Samples from index `order` onward are rebuilt by adding back the
    prediction computed from the already-reconstructed previous samples.

    ([int], int) -> [int]
    """

    residuals, order = block
    # Weights applied to sample[i-1], sample[i-2], sample[i-3] per order.
    weights = ((), (1,), (2, -1), (3, -3, 1))[order]

    samples = list(residuals)
    for pos in range(len(weights), len(samples)):
        guess = 0
        for lag, weight in enumerate(weights, 1):
            guess += weight * samples[pos - lag]
        samples[pos] += guess
    return samples
    

def rice_encode(block):
    """Encode a prediction-coded block with Rice codes.

    Input is the pair ``(residuals, predictor_type)``.  Residuals are mapped
    to non-negative integers (``v >= 0 -> 2v``, ``v < 0 -> -2v - 1``).  The
    Rice parameter ``q`` in 0..15 giving the shortest bit stream is chosen,
    with the smallest ``q`` winning ties.  Each value ``u`` is written
    LSB-first into 16-bit words as ``u >> q`` zero bits, a single one bit,
    then the low ``q`` bits of ``u``.

    ([int], int) -> (array.array('H'), int, int, int)
    Returns ``(bit_words, q, cost_in_bits, predictor_type)``.
    """

    residuals, pt = block
    values = [-2 * val - 1 if val < 0 else 2 * val for val in residuals]

    # Bit cost per q: unary prefixes plus (1 + q) bits per value.
    codecosts = [sum(val >> q for val in values) + (1 + q) * len(values)
                 for q in range(16)]
    cost, bestq = min(zip(codecosts, range(16)))

    # Word count must stay (cost - 1) // 16 + 3, because the decoder
    # recomputes each frame's length from `cost` with this formula.  `//`
    # keeps it an int on Python 3 (plain `/` yields a float there), and a
    # list initializer works for array.array on both Python 2 and 3.
    bitstack = array.array('H', [0]) * ((cost - 1) // 16 + 3)

    mask = (1 << bestq) - 1
    stackpos = 0
    for val in values:
        # Skip over the unary zero bits.
        stackpos += val >> bestq
        word, offset = stackpos >> 4, stackpos & 15
        # Terminating one bit in the LSB, remainder bits above it.
        code = ((val & mask) << 1) + 1
        # The code may straddle two 16-bit words.
        bitstack[word] |= 0xffff & (code << offset)
        bitstack[word + 1] |= 0xffff & (code >> (16 - offset))
        stackpos += bestq + 1

    return bitstack, bestq, cost, pt

    
def rice_decode(block):
    """Decode a Rice-coded block back to prediction residuals.

    Inverse of rice_encode: `block` is
    ``(bit_words, q, cost_in_bits, predictor_type)``.  Values are read until
    `cost` bits have been consumed.

    (array.array, int, int, int) -> ([int], int)
    Returns ``(residuals, predictor_type)``.
    """

    bits, q, cost, pt = block
    mask = (1 << q) - 1
    values = []
    stackpos = 0
    while stackpos < cost:
        # Unary part: count zero bits up to the terminating one bit.
        shift = 0
        while bits[stackpos >> 4] & (1 << (stackpos & 15)) == 0:
            stackpos += 1
            shift += 1
        stackpos += 1

        # The q remainder bits may straddle two 16-bit words.
        word, offset = stackpos >> 4, stackpos & 15
        rem = (bits[word] >> offset) & mask
        rem |= (bits[word + 1] << (16 - offset)) & mask

        values.append((shift << q) + rem)
        stackpos += q

    # Undo the zigzag mapping (odd -> negative, even -> non-negative).  `//`
    # keeps the results ints on Python 3 and equals Python 2's `/` here,
    # since both divisions are exact.
    residuals = [-((val + 1) // 2) if val & 1 else val // 2 for val in values]
    return residuals, pt

def byte_encode(block):
    """Serialise a Rice-coded block into one frame of bytes.

    The frame starts with a 4-byte little-endian header packing the
    predictor type (bits 0-1), the Rice parameter q (bits 2-5) and the bit
    cost (bits 6-31).  The 16-bit code words follow in native byte order.

    (array.array, int, int, int) -> string

    Raises ValueError if a field does not fit in its header bits.
    """

    bits, q, cost, pt = block
    # Reject values that would spill into a neighbouring header field or
    # past 32 bits.
    if not (0 <= pt < 4 and 0 <= q < 16 and 0 <= cost < 1 << 26):
        raise ValueError('header field out of range: pt=%r q=%r cost=%r'
                         % (pt, q, cost))
    header = pt + (q << 2) + (cost << 6)

    # array.tostring() was removed in Python 3.9; Python 2 lacks tobytes().
    if hasattr(bits, 'tobytes'):
        payload = bits.tobytes()
    else:
        payload = bits.tostring()

    # '<I' is a fixed 4-byte little-endian unsigned int.  Native 'L' is
    # 8 bytes on most 64-bit Unix builds.
    return struct.pack('<I', header) + payload
    
def byte_decode(block):
    """Parse one encoded frame back into a Rice-coded block.

    Inverse of byte_encode.  It reads a 4-byte little-endian header (bits
    0-1 predictor type, bits 2-5 Rice parameter q, bits 6-31 bit cost),
    followed by 16-bit words in native byte order.

    string -> (array.array, int, int, int)
    Returns ``(bit_words, q, cost_in_bits, predictor_type)``.
    """

    # '<I' is a fixed 4-byte little-endian unsigned int.  Native 'L' is
    # 8 bytes on most 64-bit Unix builds and would not match block[:4].
    header = struct.unpack('<I', block[:4])[0]
    pt, q, cost = header & 3, (header >> 2) & 15, header >> 6

    bits = array.array('H')
    payload = block[4:]
    # array.fromstring() was removed in Python 3.9; Python 2 lacks
    # frombytes().
    if hasattr(bits, 'frombytes'):
        bits.frombytes(payload)
    else:
        bits.fromstring(payload)

    return bits, q, cost, pt
    
    
def encode_block(block):
    """Run the full encoding pipeline on one block of samples.

    Stages: prediction (pred_encode), then Rice coding (rice_encode), then
    serialisation to a byte frame (byte_encode).

    [int] -> string
    """

    return byte_encode(rice_encode(pred_encode(block)))
    
def decode_block(block):
    """Run the full decoding pipeline on one encoded byte frame.

    Stages, the reverse of encode_block: byte_decode, then rice_decode,
    then pred_decode.

    string -> [int]
    """

    for stage in (byte_decode, rice_decode, pred_decode):
        block = stage(block)
    return block
    
def encode_main(filename_in, filename_out):
    """Encode the WAV file `filename_in` and write the result to `filename_out`.

    Each channel is split into blocks and every block is encoded
    independently.  The encoded frames are written alternately: left block
    0, right block 0, left block 1, ...  An existing output file is
    overwritten.

    (string, string) -> void
    """

    l_ch, r_ch = wave_read(filename_in)
    left_frames = [encode_block(b) for b in channel_to_blocks(l_ch)]
    right_frames = [encode_block(b) for b in channel_to_blocks(r_ch)]

    interleaved = [None] * (2 * len(left_frames))
    interleaved[0::2] = left_frames
    interleaved[1::2] = right_frames

    # Frames are byte strings, so join with b'': on Python 3 a str
    # separator cannot join bytes objects.
    with open(filename_out, 'wb') as output:
        output.write(b''.join(interleaved))
    
def decode_main(filename_in, filename_out):
    """Decode the file `filename_in` and write the WAV result to `filename_out`.

    The input is a sequence of frames alternating between left and right
    channel blocks.  Each frame is a 4-byte little-endian header whose bits
    6-31 hold the bit cost, followed by 16-bit code words.  The frame length
    is recomputed from that cost.  An existing output file is overwritten.

    (string, string) -> void

    Raises ValueError if the input ends in the middle of a frame.
    """

    with open(filename_in, 'rb') as source:
        content = source.read()

    pos = 0
    frames = []
    while pos < len(content):
        if pos + 4 > len(content):
            raise ValueError('truncated frame header at byte %d' % pos)
        # '<I': fixed 4-byte little-endian, independent of native 'L' size.
        cost = struct.unpack('<I', content[pos:pos + 4])[0] >> 6
        # Must mirror rice_encode's allocation of (cost - 1) // 16 + 3
        # sixteen-bit words; `//` keeps this an int on Python 3.
        framelen = 4 + ((cost - 1) // 16 + 3) * 2
        if pos + framelen > len(content):
            raise ValueError('truncated frame at byte %d' % pos)
        frames.append(content[pos:pos + framelen])
        pos += framelen

    left_blocks = [decode_block(frame) for frame in frames[0::2]]
    right_blocks = [decode_block(frame) for frame in frames[1::2]]

    wave_write(filename_out,
               blocks_to_channel(left_blocks),
               blocks_to_channel(right_blocks))

# Command-line entry point, currently DISABLED: the code below is wrapped in
# a bare string literal, so it never executes, not even when this file is
# run as a script.  Intended usage, per the comments inside it:
#   encode: <script> INPUT.wav OUTPUT
#   decode: <script> INPUT OUTPUT.wav d   ('d' may appear at any position)
# In both modes the output file is overwritten without confirmation.
"""
if __name__ == '__main__':
    #Żeby zakodować plik WAV w liście parametrów \
    #należy podać nazwę pliku wejściowego i wyjściowego.
    #
    #Żeby odkodować należy dodać w dowolnym miejscu literę 'd'.
    #
    #UWAGA: w obu przypadkach plik wyjściowy zostanie NADPISANY \
    #bez pytania o potwierdzenie!
    
    encode = True
    
    ar = sys.argv
    assert len(ar) in [3, 4]
    if len(ar) == 4:
        assert 'd' in ar
        encode = False
        ar.remove('d')
    
    if encode:
        encode_main(*ar[1:])
    else:
        decode_main(*ar[1:])
"""  
 

