# -*- coding: utf-8 -*-
"""
Python code for basic FIR filter design
@author: Matti Pastell <matti.pastell@helsinki.fi>
http://mpastell.com
"""

import sys
from pylab import *
import scipy.signal as signal

def mfreqz(b,a=1):
    """Plot the magnitude (dB) and unwrapped phase response of the filter b/a.

    The magnitude is drawn into subplot 211 and the phase into subplot 212
    of the current figure.

    Parameters
    ----------
    b : array_like
        Numerator (FIR tap) coefficients.
    a : array_like or scalar, optional
        Denominator coefficients; 1 for an FIR filter.
    """
    w, h = signal.freqz(b, a)
    # freqz samples w over [0, pi) rad/sample. Divide by pi itself, not by
    # max(w) (which is slightly below pi), so the axis really is in units of
    # pi rad/sample as the label states.
    w_norm = w / pi
    h_dB = 20 * log10(abs(h))
    subplot(211)
    plot(w_norm, h_dB)
    ylim(-150, 5)
    ylabel('Magnitude (db)')
    xlabel(r'Normalized Frequency (x$\pi$rad/sample)')
    title(r'Frequency response')
    subplot(212)
    # angle(h) is arctan2(imag(h), real(h)); unwrap removes the 2*pi jumps.
    h_Phase = unwrap(angle(h))
    plot(w_norm, h_Phase)
    ylabel('Phase (radians)')
    xlabel(r'Normalized Frequency (x$\pi$rad/sample)')
    title(r'Phase response')
    subplots_adjust(hspace=0.5)

def impz(b, a=1, nsamples=None):
    """Plot the impulse response and step response of the filter b/a.

    The impulse response is obtained by running a unit impulse through
    scipy.signal.lfilter; the step response is its cumulative sum. Both are
    drawn as stem plots into subplots 211 and 212 of the current figure.

    Parameters
    ----------
    b : array_like
        Numerator (FIR tap) coefficients.
    a : array_like or scalar, optional
        Denominator coefficients; 1 for an FIR filter.
    nsamples : int, optional
        Number of samples to compute and plot. Defaults to len(b), which
        spans the whole impulse response of an FIR filter (a=1). An IIR
        filter's response generally lasts longer, so pass a larger value.
    """
    length = len(b) if nsamples is None else int(nsamples)
    impulse = zeros(length)
    impulse[0] = 1.0
    x = arange(length)
    response = signal.lfilter(b, a, impulse)
    subplot(211)
    stem(x, response)
    ylabel('Amplitude')
    xlabel(r'n (samples)')
    title(r'Impulse response')
    subplot(212)
    # The step response is the running sum of the impulse response.
    step = cumsum(response)
    stem(x, step)
    ylabel('Amplitude')
    xlabel(r'n (samples)')
    title(r'Step response')
    subplots_adjust(hspace=0.5)


# Sampling frequency (presumably Hz) and the matching Nyquist frequency.
Fe = 5000.0
Fniq = Fe / 2
# Band-edge ratio used by bandpass(): each pass band runs from
# freq / transition up to freq * transition.
transition = 1.001

# Number of FIR taps per filter.
n = 32
# Centre frequencies of the three band-pass filters (same units as Fe).
Ffond = 500
Fharm1 = 1500.0
Fother = 850.0

def bandpass(n, freq, nyquist=None, ratio=None):
    """Design an n-tap linear-phase FIR band-pass filter centred on freq.

    The pass band spans freq / ratio to freq * ratio. The filter is the
    difference of two boxcar-window low-pass filters from
    scipy.signal.firwin. It is then scaled so that the peak of its magnitude
    response (as sampled by scipy.signal.freqz) equals 1.

    Parameters
    ----------
    n : int
        Number of taps.
    freq : float
        Centre frequency, in the same units as nyquist.
    nyquist : float, optional
        Nyquist frequency. Defaults to the module-level Fniq, read at call time.
    ratio : float, optional
        Band-edge ratio (expected > 1). Defaults to the module-level
        transition, read at call time.

    Returns
    -------
    ndarray
        The n normalized filter coefficients.
    """
    if nyquist is None:
        nyquist = Fniq
    if ratio is None:
        ratio = transition
    # float() avoids integer division under Python 2 if both args are ints.
    centre = float(freq) / nyquist
    low = signal.firwin(n, cutoff=centre / ratio, window='boxcar')
    high = signal.firwin(n, cutoff=centre * ratio, window='boxcar')
    # Inverting a band-stop built from low + spectrally-inverted(high) gives
    # delta - (low + delta - high) = high - low: the unit impulses cancel.
    # Computing the difference directly is equivalent and needs no centre-tap
    # index. Indexing with n/2 would be a float, and fail, under Python 3.
    d = high - low
    _, h = signal.freqz(d, 1)
    return d / abs(h).max()

def c_dump(fond, harm1, other, scale=32768, out=None):
    """Write C source for apply_filters(), an unrolled evaluation of three filters.

    For every index i, the generated code tests bit i of `bitfield`. When the
    bit is set, it adds the quantized i-th coefficient of each filter to
    fil_fond, fil_harm1 and fil_other. Under HOST_VERSION the three sums are
    passed through saturate16().

    Parameters
    ----------
    fond, harm1, other : sequence of float
        Coefficients of the three filters; must all have the same length.
    scale : int, optional
        Fixed-point scale factor applied before rounding (32768, i.e. Q15).
    out : file-like, optional
        Stream the C code is written to; defaults to sys.stdout.

    Raises
    ------
    ValueError
        If the three coefficient sequences differ in length.
    """
    import math

    def quantize(coeffs):
        # Round to nearest (half up) rather than truncating toward zero,
        # which halves the worst-case quantization error and removes the bias
        # toward zero. math.floor gives the same result on Python 2 and 3.
        return [int(math.floor(c * scale + 0.5)) for c in coeffs]

    f, h, o = quantize(fond), quantize(harm1), quantize(other)
    if not (len(f) == len(h) == len(o)):
        raise ValueError("coefficient lengths differ: %d, %d, %d"
                         % (len(f), len(h), len(o)))

    lines = [
        "/* Apply the 3 filters on the bitfield. The results are stored in",
        " * fil_fond, fil_harm1 and fil_other. */",
        "static void apply_filters(void)",
        "{",
        "\tfil_fond = 0;",
        "\tfil_harm1 = 0;",
        "\tfil_other = 0;",
        "",
    ]
    # NOTE(review): C only guarantees unsigned long is at least 32 bits, so
    # "1UL << i" may not be valid for i >= 32 on every target.
    for i, (fc, hc, oc) in enumerate(zip(f, h, o)):
        lines.extend([
            "\tif (bitfield & (1UL << %d)) {" % i,
            "\t\tfil_fond += %d;" % fc,
            "\t\tfil_harm1 += %d;" % hc,
            "\t\tfil_other += %d;" % oc,
            "\t}",
            "",
        ])
    lines.extend([
        "#ifdef HOST_VERSION",
        "\tfil_fond = saturate16(fil_fond);",
        "\tfil_harm1 = saturate16(fil_harm1);",
        "\tfil_other = saturate16(fil_other);",
        "#endif",
        "}",
    ])

    if out is None:
        out = sys.stdout
    # A single write with an explicit trailing newline works identically on
    # Python 2 and 3, unlike the print statement.
    out.write("\n".join(lines) + "\n")


# Design the three band-pass filters, emit their C implementation on
# stdout, then plot their frequency responses in figure 3.
figure(3)
fond, harm1, other = (bandpass(n, centre) for centre in (Ffond, Fharm1, Fother))
c_dump(fond, harm1, other)

# Frequency response: each call draws into the same 211/212 subplots.
mfreqz(fond)
mfreqz(harm1)
mfreqz(other)

# Save to the file named on the command line, otherwise display on screen.
if sys.argv[1:]:
    savefig(sys.argv[1])
else:
    show()
