import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import matplotlib
import sys
sys.path.append('../python')
import warnings
warnings.filterwarnings('ignore')
from ctw import CTW
# Output settings: figures are written to fig_dir when save_figs is True.
save_figs = True
fig_dir = '../../results/figs/'
# Seed NumPy's global generator so the sampled sequences are reproducible.
np.random.seed(0)
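This notebook explores the measurement of causal influence between two ternary jointly Markov random processes, $\{X_i\}_{i=1}^n$ and $\{Y_i\}_{i=1}^n$. We'll look at the following three scenarios: (1) the two processes are independent; (2) the influence is unidirectional, with $Y$ driving $X$; and (3) the influence is bidirectional, with each process driving the other.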
First, we define four state transition matrices and one fixed distribution. For the independent case, we define the probabilities of transitioning from $x_i$ to $x_{i+1}$ and from $y_i$ to $y_{i+1}$. In the unidirectional case, $Y$ evolves without any influence from $X$ (each $y_{i+1}$ is drawn from the fixed distribution `yind`), and we define the probabilities of transitioning from $(x_i,y_i)$ to $(x_{i+1},\cdot)$. Lastly, for the bidirectional case, we define the probabilities of transitioning from $(x_i,y_i)$ to $(\cdot,y_{i+1})$, which are used together with the matrix from $(x_i,y_i)$ to $(x_{i+1},\cdot)$.
# One-step transition tables for the three models. Each row is a next-state
# distribution over {0, 1, 2}, written as numerators over 9.
#
# Independent model: x and y each follow their own Markov chain, keyed by the
# current value as a string ('0', '1', '2').
x2x = {
    '0': [3/9, 5/9, 1/9],
    '1': [1/9, 6/9, 2/9],
    '2': [1/9, 4/9, 4/9],
}
y2y = {
    '0': [3/9, 5/9, 1/9],
    '1': [1/9, 7/9, 1/9],
    '2': [1/9, 5/9, 3/9],
}
# Unidirectional model: y is drawn from this fixed distribution at every step.
yind = [1/9, 5/9, 3/9]
# Joint-context tables keyed by the string '(x,y)' of the current pair:
# xy2x gives the next-x distribution, xy2y the next-y distribution.
xy2x = {
    '(0,0)': [5/9, 3/9, 1/9],
    '(1,0)': [3/9, 5/9, 1/9],
    '(2,0)': [2/9, 5/9, 2/9],
    '(0,1)': [5/9, 3/9, 1/9],
    '(1,1)': [1/9, 7/9, 1/9],
    '(2,1)': [1/9, 5/9, 3/9],
    '(0,2)': [3/9, 5/9, 1/9],
    '(1,2)': [1/9, 6/9, 2/9],
    '(2,2)': [1/9, 4/9, 4/9],
}
xy2y = {
    '(0,0)': [5/9, 3/9, 1/9],
    '(1,0)': [7/9, 1/9, 1/9],
    '(2,0)': [2/9, 6/9, 1/9],
    '(0,1)': [3/9, 5/9, 1/9],
    '(1,1)': [3/9, 5/9, 1/9],
    '(2,1)': [1/9, 5/9, 3/9],
    '(0,2)': [6/9, 2/9, 1/9],
    '(1,2)': [1/9, 5/9, 3/9],
    '(2,2)': [1/9, 4/9, 4/9],
}
def probs(x, y, mode, reverse):
    """Return the model's next-step distributions ``(for_x, for_y)``.

    Parameters
    ----------
    x, y : int in {0, 1, 2}
        Current values of the two processes.
    mode : {'indep', 'unidir', 'bidir'}
        Which generating model to read the tables from.
    reverse : bool
        If True, the roles of the processes are swapped: ``x`` is treated as
        the model's Y value and ``y`` as its X value, and the returned pair is
        swapped back so that element 0 still belongs to the ``x`` argument.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported names. (Previously any
        unknown mode silently fell through to the bidirectional tables.)
    """
    if reverse:
        x, y = y, x
    if mode == 'indep':
        xprobs, yprobs = x2x[str(x)], y2y[str(y)]
    elif mode == 'unidir':
        xprobs, yprobs = xy2x['(%i,%i)' % (x, y)], yind
    elif mode == 'bidir':
        key = '(%i,%i)' % (x, y)
        xprobs, yprobs = xy2x[key], xy2y[key]
    else:
        raise ValueError("mode must be 'indep', 'unidir' or 'bidir'")
    return (yprobs, xprobs) if reverse else (xprobs, yprobs)
Using these probabilities, we define 3 functions for sampling the next state $(x_{i+1},y_{i+1})$ given a previous state $(x_i,y_i)$ and another function to generate $n$ samples using one of the models.
def indpedent_sample(x_, y_):
    """Draw the next state (x, y) of the independent model.

    Each coordinate steps through its own chain (rows of ``x2x`` and
    ``y2y``), with no cross-influence. The misspelt name is kept because
    ``gen_seqs`` refers to it.
    """
    x_draw = np.random.multinomial(1, x2x[str(x_)])
    y_draw = np.random.multinomial(1, y2y[str(y_)])
    return np.argmax(x_draw), np.argmax(y_draw)
def unidirectional_sample(x_, y_):
    """Draw the next state (x, y) of the unidirectional (Y -> X) model.

    The next x depends on the current pair through ``xy2x``; the next y is
    drawn from ``yind`` regardless of the current state.
    """
    context = '(%i,%i)' % (x_, y_)
    x_next = np.argmax(np.random.multinomial(1, xy2x[context]))
    y_next = np.argmax(np.random.multinomial(1, yind))
    return x_next, y_next
def bidirectional_sample(x_, y_):
    """Draw the next state (x, y) of the bidirectional model.

    Both coordinates depend on the current pair (``xy2x`` and ``xy2y``);
    given that pair they are drawn independently, x first.
    """
    context = '(%i,%i)' % (x_, y_)
    x_next = np.argmax(np.random.multinomial(1, xy2x[context]))
    y_next = np.argmax(np.random.multinomial(1, xy2y[context]))
    return x_next, y_next
def gen_seqs(n=500, mode='indep'):
    """Sample ``n`` consecutive states from one of the three ternary models.

    Parameters
    ----------
    n : int
        Number of (x, y) samples to draw.
    mode : {'indep', 'unidir', 'bidir'}
        Selects the one-step sampler that drives the chain.

    Returns
    -------
    (list, list)
        The sampled x and y sequences. The chain starts from the state
        (1, 1), which is not itself part of the output.

    Raises
    ------
    TypeError
        If ``mode`` is not one of the supported names.
    """
    for name, step in (('indep', indpedent_sample),
                       ('unidir', unidirectional_sample),
                       ('bidir', bidirectional_sample)):
        if mode == name:
            break
    else:
        raise TypeError("mode must be 'indep', 'unidir' or 'bidir'")

    state = (1, 1)
    xs, ys = [], []
    for _ in range(n):
        state = step(*state)
        xs.append(state[0])
        ys.append(state[1])
    return xs, ys
def kl(p, q):
    """Return the Kullback-Leibler divergence D(p || q) in bits.

    Parameters
    ----------
    p, q : sequence of float
        Probability vectors. Only the first ``len(p)`` entries of ``q`` are
        read.

    Returns
    -------
    float
        The divergence. Entries with ``p[i] == 0`` contribute nothing (the
        usual 0 * log(0 / q) = 0 convention), so a zero in ``p`` no longer
        produces NaN. If ``q[i] == 0`` while ``p[i] > 0``, the divergence is
        infinite and ``inf`` is returned instead of raising.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)[:p.size]
    support = p > 0
    # q == 0 on the support legitimately yields inf; silence the warning.
    with np.errstate(divide='ignore'):
        terms = p[support] * np.log2(p[support] / q[support])
    return float(np.sum(terms))
def unidirectional_measure(xs, ys):
    """Causal measure C_{Y->X}(t) of the unidirectional model, per sample.

    The complete predictor of x_{t+1} uses the observed pair (x_t, y_t), i.e.
    row ``xy2x['(x_t,y_t)']``. The restricted predictor sees only x_t and
    averages the same rows over y_t drawn from ``yind``, which is the
    distribution ``unidirectional_sample`` draws y from.

    Returns
    -------
    list of float
        One KL divergence (bits, via ``kl``) per (x_t, y_t) pair.
    """
    # The restricted distribution depends only on x_t, so build it once for
    # each of the three possible values instead of once per sample.
    restricted = [
        np.dot(np.asarray([xy2x['(%i,%i)' % (x, j)] for j in range(3)]).T, yind)
        for x in range(3)
    ]
    return [kl(xy2x['(%i,%i)' % (x, y)], restricted[x]) for x, y in zip(xs, ys)]
def bidirectional_measure(xs,ys):
    """Approximate causal measures in both directions for the bidirectional model.

    For each observed pair (x_t, y_t), compare the complete next-step
    distributions (rows of ``xy2x`` / ``xy2y`` keyed by the observed pair)
    with restricted ones. In the restricted ones the other process is
    replaced by a running belief vector: ``yprobs`` stands in for the hidden
    y when predicting x, and ``xprobs`` for the hidden x when predicting y.
    Both beliefs start uniform.

    Returns ``(cyx, cxy)``: lists of KL divergences (bits, via ``kl``), one
    entry per pair, for Y->X and X->Y respectively.

    NOTE(review): the beliefs are only pushed through the transition rows.
    They are never conditioned on the newly observed values. Also, the
    restricted distribution at step t is built from a belief that has already
    been advanced using the current pair. This looks like an approximation
    rather than an exact filter (``true_measure`` renormalises its posterior
    on each observation). Confirm before relying on these values.
    """
    cyx = []
    cxy = []
    xprobs = np.asarray([1/3,1/3,1/3])
    yprobs = np.asarray([1/3,1/3,1/3])
    # note: the enumerate index `i` is unused; the comprehensions below bind
    # their own `i` (separate scope in Python 3)
    for i,(x,y) in enumerate(zip(xs,ys)):
        # get the complete distribution for x and y
        fcx = xy2x['(%i,%i)'%(x,y)]
        fcy = xy2y['(%i,%i)'%(x,y)]
        # matrices for updating probabilities of hidden x and y
        # (row i of xmat: next-x distribution if x were i, with observed y;
        #  row i of ymat: next-y distribution if y were i, with observed x)
        xmat = np.asarray([xy2x['(%i,%i)'%(i,y)] for i in range(3)])
        ymat = np.asarray([xy2y['(%i,%i)'%(x,i)] for i in range(3)])
        # update hidden probabilities
        xprobs = np.dot(xmat.T,xprobs)
        yprobs = np.dot(ymat.T,yprobs)
        # matrices for computing restricted distributions from hidden probabilities
        xmat = np.asarray([xy2x['(%i,%i)'%(x,i)] for i in range(3)])
        ymat = np.asarray([xy2y['(%i,%i)'%(i,y)] for i in range(3)])
        # get the restricted distributions (mixtures of the rows above,
        # weighted by the belief over the hidden process)
        frx = np.dot(xmat.T,yprobs)
        fry = np.dot(ymat.T,xprobs)
        # update causal measures
        cyx.append(kl(fcx,frx))
        cxy.append(kl(fcy,fry))
    return cyx,cxy
def true_measure(xs, ys, mode, reverse):
    """Exact causal measure from the first process to the second, per step.

    At each time t, compare the complete next-step distribution of the second
    process, P(y_{t+1} | x_t, y_t), with the restricted one,
    P(y_{t+1} | y_0..y_t), in which the first process is hidden. The
    restricted distribution comes from a forward filter over the hidden x_t:

    * the prior over x_0 is the model's step from the start state (1, 1) used
      by ``gen_seqs``. x_0 and y_0 are drawn independently from that state,
      so this is also P(x_0 | y_0);
    * after scoring step t, the belief over x_t is conditioned on the next
      observation y_{t+1}, renormalised, and pushed one step through the
      x-transition (x_{t+1} and y_{t+1} are drawn independently given the
      current pair).

    Parameters
    ----------
    xs, ys : sequences of int in {0, 1, 2}
        The hidden/influencing process and the predicted process. Only the
        first ``min(len(xs), len(ys))`` samples are used.
    mode : {'indep', 'unidir', 'bidir'}
        Generating model, passed through to ``probs``.
    reverse : bool
        Passed to ``probs``. Set it when ``xs`` holds the model's Y process
        and ``ys`` its X process, e.g. ``true_measure(y, x, mode, True)`` for
        C_{Y->X}.

    Returns
    -------
    list of float
        One KL divergence (bits) per time step. Entry t scores the prediction
        of y_{t+1}, so the last entry refers to a sample not present in ``ys``.
    """
    # A single recursion now covers every step. The former special cases for
    # t = 0 and t = 1 reused leftover loop variables (the final value 2 of
    # `for x0` / `for y1`) where xs[0] / ys[1] were meant, and conditioned the
    # first posterior on ys[0] instead of ys[1].
    n = min(len(xs), len(ys))
    cxy = []
    if n == 0:
        return cxy

    # belief[j] = P(x_t = j | y_0..y_t)
    belief = np.asarray(probs(1, 1, mode, reverse)[0], dtype=float)
    for t in range(n):
        y_t = ys[t]
        # Row j: next-step distribution of (first, second) process given
        # hidden x_t = j and the observed y_t.
        x_rows = np.asarray([probs(j, y_t, mode, reverse)[0] for j in range(3)])
        y_rows = np.asarray([probs(j, y_t, mode, reverse)[1] for j in range(3)])

        restricted = y_rows.T @ belief
        complete = probs(xs[t], y_t, mode, reverse)[1]
        cxy.append(kl(complete, restricted))

        if t + 1 < n:
            # Condition on the observed y_{t+1}, then advance x by one step.
            posterior = belief * y_rows[:, ys[t + 1]]
            posterior /= posterior.sum()
            belief = x_rows.T @ posterior
    return cxy
def unidirectional_partial_measure(xs, ys):
    """Partial (stale-side-information) causal measure Y->X, per step.

    For each step t >= k, the complete predictor of x_{t+1} uses the observed
    pair (x_t, y_t). The restricted predictor sees x_t but only the stale
    value y_{t-k}. Its belief over y_t is a point mass on y_{t-k} pushed
    through ``k`` steps of the ``y2y`` transition matrix. ``k`` is fixed to 1.

    NOTE(review): ``unidirectional_sample`` draws y i.i.d. from ``yind``
    rather than from ``y2y``. If that sampler is the intended model, the
    belief over y_t should presumably be ``yind`` here; confirm.

    Returns
    -------
    list of float
        One KL divergence (bits, via ``kl``) per step from index k onward.
    """
    k = 1
    # Row j: distribution of the next y given current y = j.
    y_transition = np.asarray([y2y['%i' % (j)] for j in range(3)])
    cyx = []
    for start, (x, y) in enumerate(zip(xs[k:], ys[k:])):
        belief = np.eye(3)[ys[start]]
        for _ in range(k):
            belief = np.dot(y_transition.T, belief)
        x_rows = np.asarray([xy2x['(%i,%i)' % (x, j)] for j in range(3)])
        cyx.append(kl(xy2x['(%i,%i)' % (x, y)], np.dot(x_rows.T, belief)))
    return cyx
def bidirectional_partial_measure(xs, ys):
    """Partial causal measures in both directions for the bidirectional model.

    For each step t >= k, the complete predictors use the observed pair
    (x_t, y_t). Each restricted predictor sees its own process at time t but
    the other process only at time t-k. The belief over the other process is
    a point mass at t-k, propagated ``k`` steps through the joint transition
    rows using the observed values of the predicted process. ``k`` is fixed
    to 1.

    Returns
    -------
    (list of float, list of float)
        ``(cyx, cxy)``: KL divergences (bits, via ``kl``) for Y->X and X->Y,
        one per step from index k onward.
    """
    k = 1
    cyx = []
    cxy = []
    for start, (x, y) in enumerate(zip(xs[k:], ys[k:])):
        # Point-mass beliefs on the stale values x_{t-k} and y_{t-k}.
        belief_x = np.eye(3)[xs[start]]
        belief_y = np.eye(3)[ys[start]]
        for step in range(k):
            x_trans = np.asarray([xy2x['(%i,%i)' % (j, ys[start + step])] for j in range(3)])
            y_trans = np.asarray([xy2y['(%i,%i)' % (xs[start + step], j)] for j in range(3)])
            belief_x = np.dot(x_trans.T, belief_x)
            belief_y = np.dot(y_trans.T, belief_y)

        x_given = np.asarray([xy2x['(%i,%i)' % (x, j)] for j in range(3)])
        cyx.append(kl(xy2x['(%i,%i)' % (x, y)], np.dot(x_given.T, belief_y)))
        y_given = np.asarray([xy2y['(%i,%i)' % (j, y)] for j in range(3)])
        cxy.append(kl(xy2y['(%i,%i)' % (x, y)], np.dot(y_given.T, belief_x)))
    return cyx, cxy
def causality_regret(n, Knorm):
    """Upper bound on the causality regret after ``n`` samples.

    The bound is the sum of three terms:

    * a model-cost term for the complete predictor;
    * a model-cost term for the restricted predictor;
    * ``Knorm / sqrt(2)`` times the square root of the complete-model cost.

    The constants presumably encode 9 = 3 x 3 joint contexts and 3 single
    contexts for ternary processes (2 free parameters each); verify against
    the derivation. Works element-wise if ``n`` is a NumPy array.
    """
    complete_cost = 10 + ((9 * 2) / 2) * np.log2(n / 9) + 9 * 2
    restricted_cost = 3 * np.log2(n / 3) + 3 * (3 / 2 + np.log2(3)) - 1 / 2
    mismatch = 1 / np.sqrt(2) * Knorm * np.sqrt(complete_cost)
    return complete_cost + restricted_cost + mismatch
def pcausality_regret(n, Knorm):
    """Upper bound on the partial-causality regret after ``n`` samples.

    Same shape as ``causality_regret``, except that the restricted-model cost
    uses the 27-context form (``31 + 27 * log2(n / 27) + 54``). That form
    presumably reflects a depth-2 restricted predictor with side information;
    verify against the derivation. Works element-wise if ``n`` is a NumPy
    array.
    """
    complete_cost = 10 + ((9 * 2) / 2) * np.log2(n / 9) + 9 * 2
    restricted_cost = 31 + ((27 * 2) / 2) * np.log2(n / 27) + 27 * 2
    mismatch = 1 / np.sqrt(2) * Knorm * np.sqrt(complete_cost)
    return complete_cost + restricted_cost + mismatch
# Preview of the normalised regret bounds n^{-1} CR(n) for several Knorm
# values: left panel uses causality_regret, right panel pcausality_regret.
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(15, 5))
n = 20000
for ax, bound in ((ax1, causality_regret), (ax2, pcausality_regret)):
    for Knorm in [1, 10, 50, 100, 500, 1000]:
        cr = [bound(i, Knorm) / i for i in range(1, n + 1)]
        ax.plot(cr, '--', label='Knorm=%i' % Knorm)
    ax.legend()
    ax.set_xlim([0, n])
    ax.set_ylim([0, 0.5])
# --- Figure 2: independent processes ---------------------------------------
font = {'family': 'serif',
        'size': 20}
matplotlib.rc('font', **font)
n = 10000
x, y = gen_seqs(n=n, mode='indep')

# Y -> X: the complete CTW predictor of x gets y as side information; the
# restricted predictor sees x only.
ctwxc = CTW(depth=1, symbols=3, sidesymbols=3)
pxcs = ctwxc.predict_sequence(x, sideseq=y)
ctwxr = CTW(depth=1, symbols=3)
pxrs = ctwxr.predict_sequence(x)
steps = pxcs.shape[1]
cyx = [kl(pxcs[:, i], pxrs[:, i]) for i in range(steps)]
# K_i = sum_j |log2(pc_j / pr_j)|; the bound takes the running L2 norm of
# K_1..K_i. A cumulative sum of squares gives every prefix norm in O(n),
# instead of re-taking np.linalg.norm of a growing list each step (O(n^2)).
Ks = np.abs(np.log2(pxcs[:3, :steps] / pxrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
cryx = [causality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cyx) + 1)]

# X -> Y: same construction with the roles swapped.
ctwyc = CTW(depth=1, symbols=3, sidesymbols=3)
pycs = ctwyc.predict_sequence(y, sideseq=x)
ctwyr = CTW(depth=1, symbols=3)
pyrs = ctwyr.predict_sequence(y)
steps = pycs.shape[1]
cxy = [kl(pycs[:, i], pyrs[:, i]) for i in range(steps)]
Ks = np.abs(np.log2(pycs[:3, :steps] / pyrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
crxy = [causality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cxy) + 1)]

fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(15, 14))
ax1.plot(cyx, c='salmon', label=r'$\hat{C}_{Y\rightarrow X}(n)$')
ax1.plot(cxy, c='c', label=r'$\hat{C}_{X\rightarrow Y}(n)$')
ax1.set_xlim([0, n])
ax1.set_ylim([0, 0.4])
ax1.legend()
ax1.grid(True)
ax2.plot(np.cumsum(cyx) / np.arange(1, len(cyx) + 1), c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$')
ax2.plot(np.cumsum(cxy) / np.arange(1, len(cxy) + 1), c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$')
ax2.plot(cryx, '--', c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$ Bound')
ax2.plot(crxy, '--', c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$ Bound')
ax2.legend(loc=1)
ax2.set_xlim([0, n])
ax2.set_ylim([0, 0.1])
ax2.grid(True)
plt.tight_layout()
if save_figs:
    plt.savefig(fig_dir + 'Fig2.png', dpi=200)
# --- Figure 3: unidirectional influence (Y drives X) ------------------------
font = {'family': 'serif',
        'size': 20}
matplotlib.rc('font', **font)
d = 1
n = 10000
boundstart = 0  # n//2
x, y = gen_seqs(n=n, mode='unidir')
true_cyx = true_measure(y, x, mode='unidir', reverse=True)

# Y -> X estimate: complete CTW uses y as side information, restricted uses x.
ctwxc = CTW(depth=1, symbols=3, sidesymbols=3)
pxcs = ctwxc.predict_sequence(x, sideseq=y)
ctwxr = CTW(depth=d, symbols=3)
pxrs = ctwxr.predict_sequence(x)
# Complete column i+d-1 is paired with restricted column i, both for the KL
# and for the bound's K_i terms, so the two stay aligned if d changes.
steps = pxrs.shape[1]
cyx = [kl(pxcs[:, i + d - 1], pxrs[:, i]) for i in range(steps)]
# Running L2 norm of K_i = sum_j |log2(pc_j / pr_j)| via a cumulative sum of
# squares: O(n) instead of re-norming a growing list every step (O(n^2)).
Ks = np.abs(np.log2(pxcs[:3, d - 1:d - 1 + steps] / pxrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
cryx = [causality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cyx) + 1)]

# X -> Y estimate (y is not influenced by x in this model).
ctwyc = CTW(depth=1, symbols=3, sidesymbols=3)
pycs = ctwyc.predict_sequence(y, sideseq=x)
ctwyr = CTW(depth=d, symbols=3)
pyrs = ctwyr.predict_sequence(y)
steps = pyrs.shape[1]
cxy = [kl(pycs[:, i + d - 1], pyrs[:, i]) for i in range(steps)]
Ks = np.abs(np.log2(pycs[:3, d - 1:d - 1 + steps] / pyrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
crxy = [causality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cxy) + 1)]

fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(15, 14))
ax1.plot(cyx, c='salmon', label=r'$\hat{C}_{Y\rightarrow X}(n)$')
ax1.plot(cxy, c='c', label=r'$\hat{C}_{X\rightarrow Y}(n)$')
ax1.plot(true_cyx[d - 1:-1], 'o', alpha=0.9, c='salmon', label=r'$C_{Y\rightarrow X}(n)$')
ax1.set_xlim([n - 100, n - 1])
ax1.set_ylim([0, 0.3])
ax1.legend()
ax1.grid(True)
ax2.plot(np.cumsum(np.abs(np.asarray(cyx) - np.asarray(true_cyx[d - 1:-1]))) / np.arange(1, len(cyx) + 1),
         c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$')
ax2.plot(np.cumsum(cxy) / np.arange(1, len(cxy) + 1), c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$')
ax2.plot(cryx, '--', c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$ Bound')
ax2.plot(crxy, '--', c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$ Bound')
ax2.legend()
ax2.set_xlim([0, n])
ax2.set_ylim([0, 0.5])
ax2.grid(True)
plt.tight_layout()
if save_figs:
    plt.savefig(fig_dir + 'Fig3.png', dpi=200)
# --- Figure 4: bidirectional influence --------------------------------------
d = 1
n = 50000
x, y = gen_seqs(n=n, mode='bidir')
true_cxy = true_measure(x, y, mode='bidir', reverse=False)
true_cyx = true_measure(y, x, mode='bidir', reverse=True)

# Y -> X estimate.
ctwxc = CTW(depth=1, symbols=3, sidesymbols=3)
pxcs = ctwxc.predict_sequence(x, sideseq=y)
ctwxr = CTW(depth=d, symbols=3)
pxrs = ctwxr.predict_sequence(x)
steps = pxrs.shape[1]
cyx = [kl(pxcs[:, i + d - 1], pxrs[:, i]) for i in range(steps)]
# Running L2 norm of K_i = sum_j |log2(pc_j / pr_j)| via a cumulative sum of
# squares: O(n) instead of re-norming a growing list every step (O(n^2)).
Ks = np.abs(np.log2(pxcs[:3, d - 1:d - 1 + steps] / pxrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
cryx = [causality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cyx) + 1)]

# X -> Y estimate.
ctwyc = CTW(depth=1, symbols=3, sidesymbols=3)
pycs = ctwyc.predict_sequence(y, sideseq=x)
ctwyr = CTW(depth=d, symbols=3)
pyrs = ctwyr.predict_sequence(y)
steps = pyrs.shape[1]
cxy = [kl(pycs[:, i + d - 1], pyrs[:, i]) for i in range(steps)]
Ks = np.abs(np.log2(pycs[:3, d - 1:d - 1 + steps] / pyrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
crxy = [causality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cxy) + 1)]

fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(15, 14))
ax1.plot(cyx, c='salmon', label=r'$\hat{C}_{Y\rightarrow X}(n)$')
ax1.plot(cxy, c='c', label=r'$\hat{C}_{X\rightarrow Y}(n)$')
ax1.plot(true_cyx[d - 1:-1], 'o', alpha=0.9, c='salmon', label=r'$C_{Y\rightarrow X}(n)$')
ax1.plot(true_cxy[d - 1:-1], 'o', alpha=0.9, c='c', label=r'$C_{X\rightarrow Y}(n)$')
ax1.set_xlim([n - 100, n - 1])
ax1.set_ylim([0, 1])
ax1.legend(loc=1)
ax1.grid(True)
ax2.plot(np.cumsum(np.abs(np.asarray(cyx) - np.asarray(true_cyx[d - 1:-1]))) / np.arange(1, len(cyx) + 1),
         c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$')
# Normalise the X -> Y regret by its own length (was len(cyx)).
ax2.plot(np.cumsum(np.abs(np.asarray(cxy) - np.asarray(true_cxy[d - 1:-1]))) / np.arange(1, len(cxy) + 1),
         c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$')
ax2.plot(cryx, '--', c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$ Bound')
ax2.plot(crxy, '--', c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$ Bound')
ax2.legend(loc=1)
ax2.set_xlim([0, n])
ax2.set_ylim([0, 0.1])
ax2.grid(True)
plt.tight_layout()
if save_figs:
    plt.savefig(fig_dir + 'Fig4.png', dpi=200)
# --- Figure 5: partial causal measures (side information delayed by d-1) ----
d = 2
# Reuses n and the bidirectional sequences x, y from the Figure 4 cell. To use
# a fresh, longer sample instead, regenerate them here with
# gen_seqs(n=..., mode='bidir').
true_cyx, true_cxy = bidirectional_partial_measure(x, y)

# Y -> X: the restricted CTW sees y only with staleness d-1.
ctwxc = CTW(depth=1, symbols=3, sidesymbols=3)
pxcs = ctwxc.predict_sequence(x, sideseq=y)
ctwxr = CTW(depth=d, symbols=3, sidesymbols=3, staleness=d - 1)
pxrs = ctwxr.predict_sequence(x, sideseq=y)
steps = pxrs.shape[1]
cyx = [kl(pxcs[:, i + d - 1], pxrs[:, i]) for i in range(steps)]
# Running L2 norm of K_i = sum_j |log2(pc_j / pr_j)| via a cumulative sum of
# squares: O(n) instead of re-norming a growing list every step (O(n^2)).
Ks = np.abs(np.log2(pxcs[:3, d - 1:d - 1 + steps] / pxrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
cryx = [pcausality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cyx) + 1)]

# X -> Y: the restricted CTW sees x only with staleness d-1.
ctwyc = CTW(depth=1, symbols=3, sidesymbols=3)
pycs = ctwyc.predict_sequence(y, sideseq=x)
ctwyr = CTW(depth=d, symbols=3, sidesymbols=3, staleness=d - 1)
pyrs = ctwyr.predict_sequence(y, sideseq=x)
steps = pyrs.shape[1]
cxy = [kl(pycs[:, i + d - 1], pyrs[:, i]) for i in range(steps)]
Ks = np.abs(np.log2(pycs[:3, d - 1:d - 1 + steps] / pyrs[:3, :steps])).sum(axis=0)
Knorms = np.sqrt(np.cumsum(Ks ** 2))
crxy = [pcausality_regret(i, Knorms[i - 1]) / i for i in range(1, len(cxy) + 1)]

fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(15, 14))
ax1.plot(cyx, c='salmon', label=r'$\hat{C}^{(%i)}_{Y\rightarrow X}(n)$' % (d - 1))
ax1.plot(cxy, c='c', label=r'$\hat{C}^{(%i)}_{X\rightarrow Y}(n)$' % (d - 1))
ax1.plot(true_cyx[:-1], 'o', alpha=0.8, c='salmon', label=r'$C^{(%i)}_{Y\rightarrow X}(n)$' % (d - 1))
ax1.plot(true_cxy[:-1], 'o', alpha=0.8, c='c', label=r'$C^{(%i)}_{X\rightarrow Y}(n)$' % (d - 1))
ax1.set_xlim([n - 100, n])
ax1.set_ylim([0, 1])
ax1.legend(loc=1)
ax1.grid(True)
ax2.plot(np.cumsum(np.abs(np.asarray(cyx) - np.asarray(true_cyx[:-1]))) / np.arange(1, len(cyx) + 1),
         c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$')
# Normalise the X -> Y regret by its own length (was len(cyx)).
ax2.plot(np.cumsum(np.abs(np.asarray(cxy) - np.asarray(true_cxy[:-1]))) / np.arange(1, len(cxy) + 1),
         c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$')
ax2.plot(cryx, '--', c='salmon', label=r'$n^{-1}CR_{Y\rightarrow X}(n)$ Bound')
ax2.plot(crxy, '--', c='c', label=r'$n^{-1}CR_{X\rightarrow Y}(n)$ Bound')
ax2.legend(loc=1)
ax2.set_xlim([0, n])
ax2.set_ylim([0, 0.1])
ax2.grid(True)
plt.tight_layout()
if save_figs:
    plt.savefig(fig_dir + 'Fig5.png', dpi=200)