from matplotlib.numerix import *
from numarray import *
from pylab import plot, legend, axis, xlabel, text, show
# numarray floating-point error policy: warn on overflow, division by zero
# and invalid operations, but stay silent on underflow (presumably because
# exp() far out in the Gaussian tails underflows to 0, which is harmless here).
Error.setMode(all=None,
              dividebyzero='warn', invalid='warn',
              overflow='warn', underflow='ignore')

def normal(x, mean, var):
    """Return the normal (Gaussian) density N(mean, var) evaluated at x.

    Parameters
    ----------
    x : scalar or array
        Point(s) at which to evaluate the density; arrays are evaluated
        element-wise.
    mean : scalar
        Mean of the distribution.
    var : scalar
        Variance (not standard deviation); must be > 0, otherwise the
        normalising constant divides by zero.

    Returns
    -------
    The density value(s), with the same shape as ``x``.
    """
    # 2.0 (not 2) forces true division even for integer arguments under
    # Python 2, where int/int floors and would corrupt the exponent.
    # exp(z) is the idiomatic and more accurate spelling of e**z.
    return (1.0 / sqrt(2.0 * pi * var)) * exp(-((x - mean) ** 2) / (2.0 * var))

# Model parameters.
m = 0.0     # mean shared by both curves
a = 40.0    # sqrt of the variance paired with m
b = 3.0     # sqrt of the variance paired with x
x = 13.7    # conditioning value for p(y|x)

# Evaluation grid: integers -150 .. 149 (arange excludes the stop value).
y = arange(-150, 150, 1)

# p(y|x): product-of-Gaussians combination of N(x, b**2) and N(m, a**2),
# i.e. a precision-weighted average of x and m with the harmonic variance.
mean = ((a**2)*x + (b**2)*m)/(a**2 + b**2)
var = ((a*b)**2)/(a**2 + b**2)
p_y_given_x = normal(y, mean, var)
# NOTE(review): p(y) uses variance a**2 + b**2; if y's prior is N(m, a**2)
# this should be a**2 -- confirm the intended model.
p_y = normal(y, m, a**2 + b**2)

plot(y, p_y_given_x)
plot(y, p_y)
legend(('p(y|x)', 'p(y)'))
axis([y[0], y[-1], 0, 0.15])
xlabel('y')
# Annotations are built from the parameters so they cannot drift out of sync
# with the plotted curves (%g renders 0.0 -> "0", 40.0 -> "40").
text(-100, 0.12, r'$\mu=%g$' % m)
text(-100, 0.11, r'$\alpha=%g$' % a)
text(-100, 0.1, r'$\beta=%g$' % b)
show()

