PINN for Phonon BTE

> Li, R., Lee, E., & Luo, T. (2021). Physics-informed neural networks for solving multiscale mode-resolved phonon Boltzmann transport equation. Materials Today Physics, 19, 100429.

https://github.com/RuiyangLi6/PINN-pBTE/tree/main/BTE_2D_Square

The author has open-sourced his PINN code, but honestly it is not easy to read — working through it, I kept thinking of Howl's Moving Castle..


This post does two things: it walks through that code, and it explains how the phonon BTE is solved with a deterministic method and how the losses for solving the phonon BTE with a NN are constructed. For the stochastic (Monte Carlo) approach to the BTE, see my earlier posts on frequency sampling in Monte Carlo and on kinetic Monte Carlo. The functions defined in this post are taken verbatim from the author's open-source code, without modification.

The System

*(Figure: the 2D square simulation domain and its boundary conditions.)*

The author computed cases in several dimensionalities; here we take the 2D system above as the example. Turn the two side boundaries into periodic boundaries, narrow the heat-sink boundary at the top, and pad the rest of the top with diffuse boundary conditions to piece together a heat-flux boundary, and you get the classic geometry of a thermal spreading problem.

Governing Equations

Under the steady-state SMRT (single-mode relaxation time) approximation, the energy-deviation-based phonon BTE reads \[ v_g \boldsymbol{s} \cdot \nabla e=\frac{e^{eq}-e}{\tau} \] Integrating over solid angle and frequency gives the energy conservation equation, \[ \nabla \cdot \boldsymbol{q}=\sum_{p} \int_{0}^{\omega_{\max , p}} \int_{4 \pi} \frac{e^{eq}-e}{4\pi \tau} d \Omega d \omega=0 \] where the heat flux is (the author's paper has a typo here — the division by \(4\pi\) is missing — but it does not affect the results) \[ \boldsymbol{q}=\sum_{p} \int_{0}^{\omega_{\max , p}} \int_{4 \pi} \frac{\boldsymbol{v} e}{4\pi} d \Omega d \omega \] Solving the BTE together with the energy conservation equation under suitable boundary conditions yields the distribution function, and from it the temperature field of the system. Before going into how the PINN discretizes and solves these equations, we need to be clear about which variables the distribution function depends on; only then can we understand what the input vector actually looks like.

\(e(\boldsymbol{x}, \boldsymbol{s}, k, p) = \hbar\omega D(\omega, p)\left[ f(\boldsymbol{x}, \boldsymbol{s}, k, p) - f^{BE}(T_{ref})\right]\) is the phonon energy deviation function, and \(e^{eq}(k, p, T) = \hbar\omega D(\omega, p)\left[f^{BE}(T) - f^{BE}(T_{ref})\right]\) is the equilibrium phonon energy deviation function. Once \(T_{ref}\) is chosen, \(f^{BE}(T_{ref})\) is fixed, so solving for \(e\) is equivalent to solving for \(f\). The point of defining such an energy deviation function is that the equilibrium distribution \(e^{eq}\) can then be approximated in a way that simplifies the solution. \[ f^{BE}(T) = \frac{1}{\exp[\hbar\omega/ (k_BT)] - 1}, \quad D(\omega, p)=\frac{k^2}{2\pi^2 v_g} \] \(D(\omega, p)\) is the phonon density of states. The dispersion relation describes the dependence between \(\omega\) and \(k\) and also sets the frequency cutoff \(\omega_m\); from it we obtain the group velocity \(v_g=\partial \omega / \partial k\). The relaxation times and the dispersion must be specified before solving the phonon BTE, so once the dispersion is given, the \(D(\omega, p)\) term is completely known. Different phonon branches have different dispersions and relaxation times; the figure below, for instance, shows the transverse acoustic (TA) and longitudinal acoustic (LA) dispersions of Si. Since the optical branches have very small group velocities and contribute little to heat transport, the phonon BTE is usually solved with only the three acoustic branches, i.e. one LA and two TA. If each branch is discretized into 10 segments, the three branches give 30 segments in total; handling more branches works the same way. The \(p\) in \(D(\omega, p)\) comes precisely from the differences between the branch dispersions.

*(Figure: TA and LA phonon dispersions of Si.)*
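As a quick numerical illustration of these definitions, here is my own sketch using the TA-branch quadratic fit \(\omega = c_1 k + c_2 k^2\) from param() further below (the conversion of its coefficients to SI units is mine):

```python
import numpy as np

hbar = 1.054572e-34   # J s
kB = 1.380649e-23     # J/K
a = 5.431e-10         # m, Si lattice constant

c1, c2 = 5.23e3, -2.26e-7                # TA fit coefficients in m/s and m^2/s
k = np.linspace(0.05, 1.0, 10)*np.pi/a   # sampled wave vectors up to the zone edge

omega = c1*k + c2*k**2               # dispersion omega(k)
vg = c1 + 2*c2*k                     # group velocity d(omega)/dk
D = k**2/(2*np.pi**2*vg)             # isotropic phonon density of states

T = 300.0
fBE = 1.0/(np.exp(hbar*omega/(kB*T)) - 1.0)   # Bose-Einstein occupation
```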

At steady state, the phonon distribution function \(f\) is a function of the spatial coordinate \(\boldsymbol{x}\), the direction vector \(\boldsymbol{s} = (\cos\theta, \sin\theta\cos\phi, \sin\theta\sin\phi)\), the wave number \(k\) (equivalent to the angular frequency \(\omega = \omega(k, p)\), since their relation is already fixed by the given dispersion), and the phonon branch \(p\). In other words, to solve for \(f\) with a deterministic method we need a multidimensional array to describe its distribution: the outermost level is the spatial distribution, and every spatial point stores a three-dimensional array — as if each point in space carried a sphere, with the distance along the radius representing the wave-vector magnitude \(k\) and the direction on the sphere representing \(\boldsymbol{s}\). If we discretize space into \(N_x\) points, the \(\theta\) direction into \(N_1\) points, and the \(\phi\) direction into \(N_2\) points (i.e. \(N_s = N_1 \times N_2\) points over the whole solid angle), and the phonon wave vector into \(N_k\) points (with \(N_k\) already including the discretization of every branch), then the array describing \(f\) has length \(N_x\times N_s \times N_k\).

*(Figure: schematic of the discretization in space, solid angle, and wave vector.)*
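In terms of bookkeeping this is nothing more than one big flattened array; a minimal sketch (the variable names are mine):

```python
import numpy as np

Nx, N1, N2, Nk = 450, 12, 12, 30   # Nk here already counts all three branches
Ns = N1*N2                          # solid-angle points per spatial point

# one scalar e per (spatial point, solid angle, wave vector):
f = np.zeros((Nx, Ns, Nk))
print(f.size)                       # Nx*Ns*Nk = 1,944,000 unknowns

# the PINN code below works with the flattened column instead:
f_flat = f.reshape(-1, 1)           # shape (Nx*Ns*Nk, 1)
```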

Small Temperature Difference Approximation

Under the small temperature difference approximation, \(e^{eq}\) can be approximated as \[ e^{eq}(k, p, T) = \hbar\omega D(\omega, p)\left[f^{BE}(T) - f^{BE}(T_{ref})\right] \approx C(\omega, p)\left(T-T_{ref}\right) \] where \[ C(\omega, p) = \hbar\omega D(\omega, p) \frac{\partial f^{BE}}{\partial T} \] This approximation is also widely used when computing temperature in Monte Carlo simulations. Its key effect is to decouple frequency from temperature: we recover the temperature field from the energy conservation equation, which, as seen above, requires an integral over the whole phonon spectrum. If frequency and temperature were coupled, that integral would have to be recomputed at every temperature, and inverting the integral value back to a temperature would be a messy nonlinear process. Once they are decoupled, the temperature is effectively pulled out of the integral — the integral takes the same value at every temperature — which simplifies the solution.
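A quick sanity check of this linearization (my own toy example, a single mode with the DOS set to 1):

```python
import numpy as np

hbar, kB = 1.054572e-34, 1.380649e-23
omega, D = 2e13, 1.0                # a single phonon mode, DOS = 1 for simplicity
Tref, dT = 300.0, 0.5

fBE = lambda T: 1.0/(np.exp(hbar*omega/(kB*T)) - 1.0)

# C(omega, p) via a central difference for dfBE/dT
C = hbar*omega*D*(fBE(Tref + 1e-3) - fBE(Tref - 1e-3))/2e-3

eeq_exact = hbar*omega*D*(fBE(Tref + dT) - fBE(Tref))
eeq_linear = C*dT
print(eeq_exact, eeq_linear)        # nearly identical for dT = 0.5 K
```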

Substituting this approximation into the energy conservation equation gives \[ (T - T_{ref}) \sum_p \int_0^{\omega_m} \frac{4\pi C(\omega, p) }{\tau}\mathrm{d}\omega = \sum_p \int_0^{\omega_m} \int_{4\pi} \frac{e }{\tau} \mathrm{d}\Omega \mathrm{d}\omega \] so the temperature can be written as \[ T - T_{ref} = \frac{1}{4\pi} \left(\sum_p \int_0^{k_m} \int_{4\pi} \frac{e}{\tau} v_g \mathrm{d}\Omega \mathrm{d}k \right)\times \left(\sum_p \int_0^{k_m} \frac{C(\omega, p) }{\tau}v_g\mathrm{d}k\right)^{-1} \] The extra group velocities come from converting the frequency integrals into wave-vector integrals.

Mesh Discretization for the PINN

The code below comes from the author's mesh_2d.py. The spatial mesh is 450 points drawn from a Sobol sequence; the paper uses a nonuniform spatial sampling, scattering extra points near the heat source at the top. The solid angle is discretized with Gauss-Legendre quadrature on a \(12\times 12\) grid over \(\cos\theta\in [-1, 1]\) and \(\phi \in [0, \pi]\). The arguments of TwoD_train_mesh(Nx, Nt, N1, N2, Nk) are: Nx, the total number of spatial points; Nt, which roughly controls the density of points near the top; N1, the number of \(\cos\theta\) quadrature points; N2, the number of \(\phi\) quadrature points; and Nk, the number of wave-vector points per branch (the author uses 10 per branch). param(k, p, T) evaluates the dispersion model specified in the paper at the given \(k, p\), returning v (group velocity), tau (relaxation time), and C (mode specific heat); p=0 is the TA branch and p=1 the LA branch. TwoD_vt(k, Tr) takes the discretized wave-vector array \(k\) and the reference temperature and returns the phonon mean free paths of the TA and LA branches. TwoD_test_mesh(Nx, N1, N2, Nk) is like TwoD_train_mesh but generates the test mesh, for which no refinement near the top is needed.

```python
# File: mesh_2d.py
import numpy as np
import torch

hbar = 1
hkB = (1.054572/1.380649)*1e2
a = 5.431  # Angstrom

def TwoD_train_mesh(Nx,Nt,N1,N2,Nk):
    # Sobol spatial points, refined near the top heat source
    soboleng = torch.quasirandom.SobolEngine(dimension=2)
    mesh = soboleng.draw(Nt*2).numpy()
    x1 = mesh[:,0]
    y1 = mesh[:,1]
    x1 = x1[y1>=0.5].reshape(-1,1)
    y1 = y1[y1>=0.5].reshape(-1,1)*0.4+0.6
    mesh = soboleng.draw(int((Nx-Nt)/0.8)).numpy()
    x2 = mesh[:,0]
    y2 = mesh[:,1]
    x2 = x2[y2<0.8].reshape(-1,1)
    y2 = y2[y2<0.8].reshape(-1,1)
    x = np.vstack((x1,x2))
    y = np.vstack((y1,y2))

    # Gauss-Legendre quadrature over the solid angle
    mu, w1 = np.polynomial.legendre.leggauss(N1)
    phi, w2 = np.polynomial.legendre.leggauss(N2)
    phi = (phi + 1) * np.pi/2

    sin_theta = np.sqrt(1 - mu**2)
    sin_theta[np.arange(0,N1,2)] = -sin_theta[np.arange(0,N1,2)]

    mu = np.tile(mu.reshape(-1,1),(1,N2)).reshape(-1,1)
    sin_theta = np.tile(sin_theta.reshape(-1,1),(1,N2)).reshape(-1,1)
    phi = np.tile(phi.reshape(-1,1),(N1,1))
    eta = sin_theta * np.cos(phi)

    w1 = np.tile(w1.reshape(-1,1),(1,N2)).reshape(-1,1)
    w2 = np.tile(w2.reshape(-1,1),(N1,1))
    w = w1*w2*np.pi

    # midpoints of Nk uniform wave-vector intervals
    k = np.linspace(0,1,Nk*2+1)[1:Nk*2+1].reshape(-1,1)
    k = k[np.arange(0,Nk*2-1,2)]*np.pi*2/a

    return x,y,mu,eta,w,k

# use reference temperature
def param(k,p,T):
    c10 = 5.23  # TA, 1e13 A/s
    c20 = -2.26
    c11 = 9.01  # LA
    c21 = -2.0
    Ai = 1.498e7
    BL = 1.18e2
    BT = 8.708
    BU = 2.89e8

    c1 = (c11-c10)*p + c10
    c2 = (c21-c20)*p + c20
    om = c1*k + c2*k**2  # unit 1e13 1/s
    v = c1 + 2*c2*k      # unit 1e13 A/s

    step = np.heaviside(k-np.pi/a,1)

    # impurity and phonon-phonon (normal/Umklapp) scattering rates
    ti = Ai*om**4
    t1 = BL*om**2*T**3
    t01 = BT*om*T**4
    t02 = BU*om**2/np.sinh(hkB*om/T)
    t0 = t01*(1-step) + t02*step
    tNU = (t1-t0)*p + t0
    tau = 1/(ti+tNU)  # s

    dfeq = hkB*om/(T**2)*np.exp(hkB*om/T)/(np.exp(hkB*om/T)-1)**2
    D = k**2/(2*np.pi**2*v)  # 1e-13 s/A^3
    C = hbar*om*D*dfeq       # Js/A^3/K

    return v,tau,C

def TwoD_vt(k,Tr):
    # log-scaled mean free paths v*tau for the TA (p=0) and LA (p=1) branches
    v0,tau0,_ = param(k,np.zeros_like(k),Tr)
    v1,tau1,_ = param(k,np.ones_like(k),Tr)
    vt0 = np.log10(v0*tau0*1e11)
    vt1 = np.log10(v1*tau1*1e11)

    return vt0,vt1

def TwoD_test_mesh(Nx,N1,N2,Nk):
    # uniform spatial grid for testing
    px = py = np.linspace(0,1,Nx)
    x,y = np.meshgrid(px,py)
    x = x.reshape(-1,1)
    y = y.reshape(-1,1)

    mu, w1 = np.polynomial.legendre.leggauss(N1)
    phi, w2 = np.polynomial.legendre.leggauss(N2)
    phi = (phi + 1) * np.pi/2

    sin_theta = np.sqrt(1 - mu**2)
    sin_theta[np.arange(0,N1,2)] = -sin_theta[np.arange(0,N1,2)]

    mu = np.tile(mu.reshape(-1,1),(1,N2)).reshape(-1,1)
    sin_theta = np.tile(sin_theta.reshape(-1,1),(1,N2)).reshape(-1,1)
    phi = np.tile(phi.reshape(-1,1),(N1,1))
    eta = sin_theta * np.cos(phi)

    w1 = np.tile(w1.reshape(-1,1),(1,N2)).reshape(-1,1)
    w2 = np.tile(w2.reshape(-1,1),(N1,1))
    w = w1*w2*np.pi

    k = np.linspace(0,1,Nk*2+1)[1:Nk*2+1].reshape(-1,1)
    k = k[np.arange(0,Nk*2-1,2)]*np.pi*2/a

    return x,y,mu,eta,w,k
```

Let's look at the distribution of the sampled points:

```python
# Visualize the shape of the spatial sampling
import numpy as np
import torch
import matplotlib.pyplot as plt

def remove_ticks(ax):
    ax.tick_params(axis=u'both', which=u'both', length=0)
    for spine in ['right', 'top', 'left', 'bottom']:
        ax.spines[spine].set_visible(False)

soboleng = torch.quasirandom.SobolEngine(dimension=2)
mesh = soboleng.draw(200*2).numpy()
x1 = mesh[:,0]
y1 = mesh[:,1]
x1 = x1[y1>=0.5].reshape(-1,1)
y1 = y1[y1>=0.5].reshape(-1,1)*0.4+0.6
mesh = soboleng.draw(int((450-200)/0.8)).numpy()
x2 = mesh[:,0]
y2 = mesh[:,1]
x2 = x2[y2<0.8].reshape(-1,1)
y2 = y2[y2<0.8].reshape(-1,1)
x = np.vstack((x1,x2))
y = np.vstack((y1,y2))

plt.scatter(x, y, s=10, marker='o')
ax = plt.gca()
ax.tick_params(axis=u'both', which=u'both', length=0)

plt.xlim(0, 1)
plt.ylim(0, 1)
ax.set_aspect('equal', adjustable='box')
```

*(Figure: the sampled spatial training points, denser in the upper part of the domain near the heat source.)*

Traditional numerical methods for PDEs rely on differences between mesh points to approximate the gradient terms, but a NN builds the input-output relation directly, so gradients come straight from differentiating the network; mesh points do not interact directly, so no regular mesh is needed. On the other hand, the prediction quality depends heavily on scattering more points where the temperature gradient is large, and every change of system, boundary condition, or material means training a new NN from scratch — I am not convinced this is a good choice for practical problems.. It is also a troublesome approach for inverse problems.
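The "gradients come for free" point is just torch.autograd; a minimal sketch of differentiating a network at scattered, mesh-free points:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))

xy = torch.rand(5, 2, requires_grad=True)   # 5 scattered points, no mesh required
e = net(xy)

# de/dx and de/dy at every point, straight from the computation graph
grads = torch.autograd.grad(e, xy, grad_outputs=torch.ones_like(e),
                            create_graph=True)[0]
de_dx, de_dy = grads[:, 0], grads[:, 1]
```

This is exactly how the BTE residual is assembled in bte_train.py below.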

Network Architecture

The paper trains two NNs: Net1 predicts the equilibrium part \(e^{eq}\), and Net0 predicts the nonequilibrium part \(e - e^{eq}\). From the small temperature difference approximation, \[ e^{eq}(k, p, T) / C(\omega, p) \approx T-T_{ref} \] so dividing \(e^{eq}\) by the corresponding mode specific heat gives the temperature, which depends only on position, and \(C(\omega, p)\) is known to us. The network predicting the equilibrium part therefore does not need a high-dimensional input like \((\boldsymbol{x}, \boldsymbol{s}, k, p)\); spatial coordinates suffice. After obtaining the prediction we multiply back by the mode specific heat to recover \(e^{eq}\). We can further use the characteristic temperature difference \(\mathrm{d}T\) of the system (\(T_{h} - T_{ref} = T_{ref} - T_{c} = \mathrm{d}T\)) to make the output dimensionless. Net1's input and output are then \[ \operatorname{Net1}(x, y, L) = \frac{e^{eq}(k, p, T)}{C(\omega, p)\mathrm{d}T} \approx \frac{T-T_{ref}}{\mathrm{d}T} \] L is the characteristic size of the system: we want the trained NN to predict the same system at different characteristic sizes — say from 10 nm to 100 um — so the input L can be chosen as [0, 1, 2, 3, 4], each value corresponding to one log-scaled characteristic length. We give Net0 the same output scaling as Net1, which makes it easy to operate on them together, so Net0's input and output are \[ \operatorname{Net0}(x, y, \cos\theta, \sin\theta\cos\phi, k, \mathrm{MFP}, L, p) = \frac{e^{neq}(x, y, \boldsymbol{s}, k, p)}{C(\omega, p)\mathrm{d}T} = \frac{e(x, y, \boldsymbol{s}, k, p) - e^{eq}(k, p, T)}{C(\omega, p)\mathrm{d}T} \] (in the code the two direction inputs are mu and eta). MFP is the mean free path at the given \(k\), \(\mathrm{MFP} = v_g \times \tau\). (I do not know why MFP is used as an input here — presumably the author tried it and found training works better with it.) With Net0 and Net1 we obtain \(e=e^{neq} + e^{eq}\), and through \(e\) we can define the losses at interior and boundary points. Since phonon mean free paths span a huge range and the characteristic size also varies across orders of magnitude, the \(\mathrm{MFP}\) actually fed to the network is \(\log_{10}(\mathrm{MFP}\times 10^{11})\). To keep the outputs at a comparable magnitude as well, so that the layer weights do not vary too wildly, the network's actual output is the raw expression above further divided by \(10^{\mathrm{MFP}-L}\) (with MFP and L in their log-scaled forms), as the code below shows.

```python
# File: model.py
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

class Net(nn.Module):
    def __init__(self, input_n, NL, NN):
        # a plain MLP: NL hidden layers of width NN, scalar output
        super(Net, self).__init__()
        self.input_layer = nn.Linear(input_n, NN)
        self.hidden_layers = nn.ModuleList([nn.Linear(NN, NN) for i in range(NL-1)])
        self.output_layer = nn.Linear(NN, 1)

    def forward(self, x):
        o = self.act(self.input_layer(x))
        for i, li in enumerate(self.hidden_layers):
            o = self.act(li(o))
        output = self.output_layer(o)
        return output

    def act(self, x):
        # Swish/SiLU activation
        return x * torch.sigmoid(x)
```
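For reference, this is how the two networks are instantiated in bte_train.py below, plus a dummy example (the numbers are mine) of undoing the output scaling just discussed:

```python
import torch
from model import Net

net0 = Net(8, 8, 30)   # inputs (x, y, mu, eta, k, MFP, L, p) -> e_neq/(C dT)
net1 = Net(3, 8, 30)   # inputs (x, y, L) -> e_eq/(C dT) ~ (T - Tref)/dT

dT = 0.5
vt, L = 2.0, 1.0                    # log-scaled MFP and log length scale (dummy values)
inp = torch.rand(1, 8)              # one dummy (position, direction, mode) sample
e_neq = net0(inp)*dT*10**vt/10**L   # recover e_neq/C, as done in bte_train.py
```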

Constructing the Loss

The loss has three parts.

First, phonons of every frequency must satisfy the Boltzmann equation, \[ \begin{aligned} & \left\Vert \boldsymbol{v_g} \cdot \nabla e + \frac{e^{neq}}{\tau} \right\Vert^2 = 0 \\ \Rightarrow & \left\Vert v_g (\cos\theta, \sin\theta\cos\phi) \cdot \left(\frac{\partial e}{\partial (Lx^*)}, \frac{\partial e}{\partial (Ly^*)}\right) + \frac{e^{neq}}{\tau} \right\Vert^2 = 0 \\ \Rightarrow & \left\Vert \cos\theta \frac{\partial e}{\partial x^*} + \sin\theta\cos\phi \frac{\partial e}{\partial y^*} + \frac{e^{neq}}{\text{MFP}/L} \right\Vert^2 = 0 \end{aligned} \] (moving \((e^{eq}-e)/\tau\) to the left-hand side turns it into \(+e^{neq}/\tau\), consistent with the code). In the code below, e0 = net0(e0_in)*(10**(vt0-L))*dT and e1 = net0(e1_in)*(10**(vt1-L))*dT are the nonequilibrium parts of the TA and LA branches, i.e. \(e^{neq} / C(\omega, p)\), and eEq = net1(torch.cat((x,y,L),1))*dT is the equilibrium part, i.e. \(e^{eq} / C(\omega, p)\). The following lines

```python
e0_x = torch.autograd.grad(e0+eEq,x,grad_outputs=torch.ones_like(x).to(device),create_graph=True)[0]
e1_x = torch.autograd.grad(e1+eEq,x,grad_outputs=torch.ones_like(x).to(device),create_graph=True)[0]
e0_y = torch.autograd.grad(e0+eEq,y,grad_outputs=torch.ones_like(y).to(device),create_graph=True)[0]
e1_y = torch.autograd.grad(e1+eEq,y,grad_outputs=torch.ones_like(y).to(device),create_graph=True)[0]
```
compute the partial derivatives of \(e/C(\omega, p)\) with respect to x and y. Loss1 and Loss2 can then be defined:
```python
loss_1 = ((mu*e0_x+eta*e0_y) + e0/(10**vt0/10**L))/dT
loss_2 = ((mu*e1_x+eta*e1_y) + e1/(10**vt1/10**L))/dT
```
mueta分别是离散的\(\cos\theta\)\(\sin\theta\sin\phi\)数组.



Second, the phonon energy conservation equation, \[ T - T_{ref} = \frac{1}{4\pi} \left(\sum_p \int_0^{k_m} \int_{4\pi} \frac{e}{\tau} v_g \mathrm{d}\Omega \mathrm{d}k \right)\times \left(\sum_p \int_0^{k_m} \frac{C(\omega, p) }{\tau}v_g\mathrm{d}k\right)^{-1} \] In the code below, the line

```python
sum_e = torch.matmul(e.reshape(-1,Ns), w).reshape(-1,1).to(device)
```
is exactly the solid-angle integral inside the first pair of brackets; since the phonon dispersion and relaxation times are assumed isotropic, that inner integral can be carried out first.

deltaT is the system temperature difference computed according to the formula above:

```python
deltaT = torch.matmul(sum_e.reshape(-1,Nk*Np),C*wk/tau*v/(4*np.pi)).reshape(-1,1)\
         .repeat(1,Nk*Ns).reshape(-1,1)/torch.sum(C/tau*wk*v)
```
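To keep track of what all these reshapes are doing, here is my reading of the flattening convention (solid-angle index fastest, then wave vector, then polarization, with the spatial point slowest), demonstrated on dummy tensors — the shapes, not the values, are the point:

```python
import numpy as np
import torch

Nbatch, Np, Nk, Ns = 4, 3, 10, 144
e = torch.rand(Nbatch*Np*Nk*Ns, 1)      # flattened e/C, angle index running fastest
w = torch.full((Ns, 1), 4*np.pi/Ns)     # dummy solid-angle weights, sum = 4*pi
C = torch.rand(Np*Nk, 1)                # dummy mode specific heats
tau = torch.rand(Np*Nk, 1) + 0.1        # dummy relaxation times
v = torch.rand(Np*Nk, 1)                # dummy group velocities
wk = 0.1                                # wave-vector integration weight

sum_e = torch.matmul(e.reshape(-1, Ns), w)     # angle integral -> (Nbatch*Np*Nk, 1)
deltaT = (torch.matmul(sum_e.reshape(-1, Np*Nk), C*wk/tau*v/(4*np.pi))
          / torch.sum(C/tau*wk*v))             # k integral + branch sum -> (Nbatch, 1)
print(deltaT.shape)                            # torch.Size([4, 1])
```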

The output of Net1 is \[ \operatorname{Net1}(x, y, L) = \frac{e^{eq}(k, p, T)}{C(\omega, p)\mathrm{d}T} \approx \frac{T-T_{ref}}{\mathrm{d}T} \] so with eEq = net1(torch.cat((x,y,L),1))*dT we can define this part of the loss,

```python
loss_3 = (deltaT - eEq)/dT
```
Of course, the energy conservation equation can be written in two equivalent ways:

\[ \nabla \cdot \boldsymbol{q}=\sum_{p} \int_{0}^{\omega_{\max , p}} \int_{4 \pi} \frac{e^{eq}-e}{4\pi \tau} d \Omega d \omega=0 \] where the heat flux is \[ \boldsymbol{q}=\sum_{p} \int_{0}^{\omega_{\max , p}} \int_{4 \pi} \frac{\boldsymbol{v_g} e}{4\pi } d \Omega d \omega \] We can also substitute the heat flux directly and take the divergence, which gives another energy-conservation loss, \[ \nabla \cdot \boldsymbol{q}=\sum_{p} \int_{0}^{k_m} \int_{4 \pi} \left( \frac{\partial e}{\partial x}v_g \cos\theta + \frac{\partial e}{\partial y}v_g\sin\theta\cos\phi \right)\frac{v_g}{4\pi} d \Omega d k \]

```python
sum_ex = torch.matmul(e_x.reshape(-1,Ns), w*mu[0:Ns].reshape(-1,1)).reshape(-1,1)
sum_ey = torch.matmul(e_y.reshape(-1,Ns), w*eta[0:Ns].reshape(-1,1)).reshape(-1,1)
dq = torch.matmul((sum_ex+sum_ey).reshape(-1,Nk*Np),C*wk*v**2/(4*np.pi)).reshape(-1,1)
```
Here w holds the author's Gauss-Legendre quadrature weights for the solid angle, so sum_ex and sum_ey carry out the solid-angle integrals, and dq is the heat-flux divergence above. The extra factor of C(omega, p) appears because the variable called \(e\) in the code is really \(e/C(\omega, p)\). TC rescales the value across the different system sizes.
```python
loss_4 = (dq/TC)
```



The third part is the boundary conditions. An isothermal boundary absorbs every phonon incident on it and emits phonons into the domain with the equilibrium distribution at the boundary temperature, \[ e(\boldsymbol{x_b}, \boldsymbol{s}, k, p) = e^{eq}(k, p, T_b), \quad \boldsymbol{s} \cdot \boldsymbol{n_b} > 0 \] which is equivalent to \[ e(\boldsymbol{x_b}, \boldsymbol{s}, k, p) / C(\omega, p) = T_b - T_{ref}, \quad \boldsymbol{s} \cdot \boldsymbol{n_b} > 0 \] In the code below, ec1 and ec2 are the cold-boundary distributions and eh1 and eh2 the hot-boundary distributions; ec2 and eh2 come from extra points added near the corners where the hot and cold boundaries meet (see the Main Loop part of the code). The paper takes \(T_h - T_{ref} = T_{ref} - T_c = \mathrm{d}T = 0.5\). From the earlier discussion of the network inputs and outputs, ec and eh in the code actually stand for \[ \frac{e}{C(\omega, p)\mathrm{d}T} \] so the condition becomes \[ \frac{e}{C(\omega, p)\mathrm{d}T} = \frac{T_b - T_{ref}}{\mathrm{d}T} = \pm 1 \] Requiring \(\boldsymbol{s} \cdot \boldsymbol{n_b} > 0\) just means substituting only half of the solid-angle array into the boundary condition. The boundary losses are then

```python
loss_5 = (ec1 + 1)
loss_6 = (ec2 + 1)
loss_7 = (eh1 - 1)
loss_8 = (eh2 - 1)
```

A diffusely reflecting boundary reads \[ e\left(\boldsymbol{x}_{b}, \boldsymbol{s}, k, p\right)=\frac{1}{\pi} \int_{\boldsymbol{s}^{\prime} \cdot \boldsymbol{n}_{b}<0} e\left(\boldsymbol{x}_{b}, \boldsymbol{s}^{\prime}, k, p\right)\left|\boldsymbol{s}^{\prime} \cdot \boldsymbol{n}_{b}\right| d \Omega, \quad \boldsymbol{s} \cdot \boldsymbol{n}_{b}>0 \] and a periodic boundary reads \[ e\left(\boldsymbol{x}_{b_{1}}, \boldsymbol{s}, k, p\right)-e^{eq}\left(k, p, T_{b_{1}}\right)=e\left(\boldsymbol{x}_{b_{2}}, \boldsymbol{s}, k, p\right)-e^{eq}\left(k, p, T_{b_{2}}\right) \] Both can be handled in the same fashion. With this we have all the losses needed to train a NN on the phonon BTE.
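The square case below uses only isothermal boundaries, but to make the recipe concrete, here is a minimal sketch (entirely mine — all names and dummy arrays are hypothetical) of how a diffuse-boundary loss term could be assembled with the same quadrature machinery:

```python
import numpy as np
import torch

Ns, Nb = 144, 30
# stand-ins for the quadrature arrays from TwoD_train_mesh:
sn = torch.rand(Ns, 1)*2 - 1            # s . n_b for each discrete direction (dummy)
w = torch.full((Ns, 1), 4*np.pi/Ns)     # solid-angle weights (dummy, sum = 4*pi)
e_b = torch.rand(Nb, Ns)                # e/(C dT) at Nb boundary points (from the nets)

incoming = (sn < 0).flatten()           # directions flying into the wall
# diffuse re-emission: (1/pi) * integral over incoming directions of e |s'.n_b| dOmega'
e_diffuse = torch.matmul(e_b[:, incoming], w[incoming]*(-sn[incoming])) / np.pi

# residual for every outgoing direction (s . n_b > 0) at every boundary point
loss_diffuse = e_b[:, ~incoming] - e_diffuse   # (Nb, 1) broadcast over the columns
```

The periodic condition is even simpler: evaluate the nets at paired points \(\boldsymbol{x}_{b_1}\) and \(\boldsymbol{x}_{b_2}\) and penalize the difference of the nonequilibrium parts.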

```python
# File: bte_train.py
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
import time
from mesh_2d import *
from model import *

def bte_train(x,y,mu,eta,L,w,k,vt0,vt1,xb,yb,kb,vt0b,vt1b,Lb,Ns,Nk,Np,Nw,logL,Tr,dT,batchsize,learning_rate,epochs,path,device):
    dataset1 = TensorDataset(torch.Tensor(x),torch.Tensor(y),torch.Tensor(L))
    dataloader1 = DataLoader(dataset1,batch_size=batchsize[0],shuffle=True,num_workers=0,drop_last=True)
    dataset2 = TensorDataset(torch.Tensor(xb),torch.Tensor(kb),torch.Tensor(vt0b),torch.Tensor(vt1b),torch.Tensor(Lb))
    dataloader2 = DataLoader(dataset2,batch_size=batchsize[1],shuffle=True,num_workers=0,drop_last=True)

    ################################################################
    net0 = Net(8, 8, 30).to(device)   # nonequilibrium part
    net1 = Net(3, 8, 30).to(device)   # equilibrium part

    optimizer0 = optim.Adam(net0.parameters(), lr=learning_rate, betas=(0.9,0.99), eps=1e-15)
    optimizer1 = optim.Adam(net1.parameters(), lr=learning_rate, betas=(0.9,0.99), eps=1e-15)

    def init_normal(m):
        if type(m) == nn.Linear:
            nn.init.kaiming_normal_(m.weight)

    net0.apply(init_normal)
    net0.train()
    net1.apply(init_normal)
    net1.train()

    ############################################################################

    def criterion(x_in,y_in,L_in,xb,kb,vt0b,vt1b,Lb):
        TC = (1/3)*torch.sum(C*v**3*tau*wk)/(dT*4)*1e11/(10**L_in).to(device)
        x = x_in.repeat(1,Ns*Nk).reshape(-1,1).to(device)
        y = y_in.repeat(1,Ns*Nk).reshape(-1,1).to(device)
        L = L_in.repeat(1,Ns*Nk).reshape(-1,1).to(device)

        x.requires_grad = True
        y.requires_grad = True

        ######### Interior points ##########
        e0_in = torch.cat((x,y,mu,eta,k,vt0,L,torch.zeros_like(x)),1).to(device)
        e1_in = torch.cat((x,y,mu,eta,k,vt1,L,torch.ones_like(x)),1).to(device)
        e0 = net0(e0_in)*dT*(10**vt0)/(10**L)
        e1 = net0(e1_in)*dT*(10**vt1)/(10**L)
        eEq = net1(torch.cat((x,y,L),1))*dT

        e0_x = torch.autograd.grad(e0+eEq,x,grad_outputs=torch.ones_like(x).to(device),create_graph=True)[0]
        e1_x = torch.autograd.grad(e1+eEq,x,grad_outputs=torch.ones_like(x).to(device),create_graph=True)[0]
        e0_y = torch.autograd.grad(e0+eEq,y,grad_outputs=torch.ones_like(y).to(device),create_graph=True)[0]
        e1_y = torch.autograd.grad(e1+eEq,y,grad_outputs=torch.ones_like(y).to(device),create_graph=True)[0]

        # two degenerate TA branches and one LA branch
        e = torch.cat(((e0+eEq).reshape(-1,Ns*Nk),(e0+eEq).reshape(-1,Ns*Nk),(e1+eEq).reshape(-1,Ns*Nk)),1).reshape(-1,1)
        e_x = torch.cat((e0_x.reshape(-1,Ns*Nk),e0_x.reshape(-1,Ns*Nk),e1_x.reshape(-1,Ns*Nk)),1).reshape(-1,1)
        e_y = torch.cat((e0_y.reshape(-1,Ns*Nk),e0_y.reshape(-1,Ns*Nk),e1_y.reshape(-1,Ns*Nk)),1).reshape(-1,1)
        sum_e = torch.matmul(e.reshape(-1,Ns), w).reshape(-1,1).to(device)
        deltaT = torch.matmul(sum_e.reshape(-1,Nk*Np),C*wk/tau*v/(4*np.pi)).reshape(-1,1).repeat(1,Nk*Ns).reshape(-1,1)/torch.sum(C/tau*wk*v)

        sum_ex = torch.matmul(e_x.reshape(-1,Ns), w*mu[0:Ns].reshape(-1,1)).reshape(-1,1)
        sum_ey = torch.matmul(e_y.reshape(-1,Ns), w*eta[0:Ns].reshape(-1,1)).reshape(-1,1)
        dq = torch.matmul((sum_ex+sum_ey).reshape(-1,Nk*Np),C*wk*v**2/(4*np.pi)).reshape(-1,1)

        ######### Isothermal boundary ##########
        xb = xb.repeat(1,int(Ns/2)).reshape(-1,1).to(device)
        kb = kb.repeat(1,int(Ns/2)).reshape(-1,1).to(device)
        Lb = Lb.repeat(1,int(Ns/2)).reshape(-1,1).to(device)
        vt0b = vt0b.repeat(1,int(Ns/2)).reshape(-1,1).to(device)
        vt1b = vt1b.repeat(1,int(Ns/2)).reshape(-1,1).to(device)
        pc = torch.cat((torch.zeros(Nb*3,1),torch.ones(Nb*3,1)),0).to(device)
        ph = torch.cat((torch.zeros(Nb,1),torch.ones(Nb,1)),0).to(device)

        x1 = torch.cat((torch.ones_like(xb),xb*0.9),1)
        x0 = torch.cat((torch.zeros_like(xb),xb*0.9),1)
        y1 = torch.cat((xb*0.8+0.1,torch.ones_like(xb)),1)
        y0 = torch.cat((xb,torch.zeros_like(xb)),1)

        cb = torch.cat((y0,x0,x1),0).repeat(2,1)
        cEq = net1(torch.cat((cb,Lb.repeat(6,1)),1))
        vtc = torch.cat((vt0b.repeat(3,1),vt1b.repeat(3,1)),0) # v*tau for cold boundary
        c_in = torch.cat((cb,sb,kb.repeat(6,1),vtc,Lb.repeat(6,1),pc),1)
        ec1 = net0(c_in)*(10**vtc)/(10**Lb.repeat(6,1)) + cEq

        hEq = net1(torch.cat((y1.repeat(2,1),Lb.repeat(2,1)),1))
        vth = torch.cat((vt0b,vt1b),0)
        h_in = torch.cat((y1.repeat(2,1),sy1,kb.repeat(2,1),vth,Lb.repeat(2,1),ph),1)
        eh1 = net0(h_in)*(10**vth)/(10**Lb.repeat(2,1)) + hEq

        # output of corner points
        ec2 = net0(c0_in)*(10**vtw)/(10**Lw) + net1(c1_in)
        eh2 = net0(h0_in)*(10**vtw)/(10**Lw) + net1(h1_in)

        ######### Loss ##########
        loss_1 = ((mu*e0_x+eta*e0_y) + e0/(10**vt0/10**L))/dT   # BTE residual, TA
        loss_2 = ((mu*e1_x+eta*e1_y) + e1/(10**vt1/10**L))/dT   # BTE residual, LA
        loss_3 = (deltaT - eEq)/dT                              # energy conservation
        loss_4 = (dq/TC)                                        # heat-flux divergence
        loss_5 = (ec1 + 1)                                      # cold isothermal boundary
        loss_6 = (ec2 + 1)                                      # cold boundary, corner points
        loss_7 = (eh1 - 1)                                      # hot isothermal boundary
        loss_8 = (eh2 - 1)                                      # hot boundary, corner points

        ##############
        # MSE LOSS
        loss_f = nn.MSELoss()

        loss1 = loss_f(loss_1,torch.zeros_like(loss_1))
        loss2 = loss_f(loss_2,torch.zeros_like(loss_2))
        loss3 = loss_f(loss_3,torch.zeros_like(loss_3))
        loss4 = loss_f(loss_4,torch.zeros_like(loss_4))
        loss5 = loss_f(loss_5,torch.zeros_like(loss_5))
        loss6 = loss_f(loss_6,torch.zeros_like(loss_6))
        loss7 = loss_f(loss_7,torch.zeros_like(loss_7))
        loss8 = loss_f(loss_8,torch.zeros_like(loss_8))

        return loss1, loss2, loss3, loss4, loss5, loss6, loss7, loss8

    ###################################################################

    # Main loop
    Loss_min = 100
    Loss_list = []
    tic = time.time()

    wk = np.pi*2/a/Nk   # wave-vector integration weight
    p = np.vstack((np.zeros((Nk,1)),np.zeros((Nk,1)),np.ones((Nk,1))))
    v,tau,C = param(np.tile(k,(Np,1)),p,Tr)
    v = torch.FloatTensor(v).to(device)
    tau = torch.FloatTensor(tau).to(device)
    C = torch.FloatTensor(C).to(device)

    Nb = int(Ns/2)*batchsize[1]
    s = np.hstack((mu,eta))
    mu = torch.FloatTensor(mu).repeat(batchsize[0]*Nk,1).to(device)
    eta = torch.FloatTensor(eta).repeat(batchsize[0]*Nk,1).to(device)
    w = torch.FloatTensor(w).to(device)

    # Solid angles for the points at the boundary
    sx0 = np.tile(s[s[:,0]>0],(batchsize[1],1))
    sx1 = np.tile(s[s[:,0]<0],(batchsize[1],1))
    sy0 = np.tile(s[s[:,1]>0],(batchsize[1],1))
    sy1 = np.tile(s[s[:,1]<0],(batchsize[1],1))
    sb = torch.FloatTensor(np.vstack((sy0,sx0,sx1))).repeat(2,1).to(device) # for cold boundary
    sy1 = torch.FloatTensor(sy1).repeat(2,1).to(device) # for hot top boundary

    k = k/(np.pi*2/a)
    Nl = len(logL)
    # extra corner points at the cold boundary (0.9 < Y < 1, X = 0 or X = 1)
    # we need these additional training points as the boundary temperature distribution is discontinuous near the top corner
    yb = torch.FloatTensor(yb).repeat(1,int(Ns/2)*Nk*Nl).reshape(-1,1)
    xc0 = torch.cat((torch.zeros_like(yb),yb*0.1+0.9),1)
    xc1 = torch.cat((torch.ones_like(yb),yb*0.1+0.9),1)
    sc0 = np.tile(s[s[:,0]>0],(Nk*Nl*Nw,1))
    sc1 = np.tile(s[s[:,0]<0],(Nk*Nl*Nw,1))
    sc = torch.FloatTensor(np.concatenate((sc0,sc1),0)).repeat(2,1).to(device)
    xc = torch.cat((xc0,xc1),0).repeat(2,1).to(device)

    # extra corner points at the hot top boundary (0 < X < 0.1 or 0.9 < X < 1, Y = 1)
    yh0 = torch.cat((yb*0.1,torch.ones_like(yb)),1)
    yh1 = torch.cat((yb*0.1+0.9,torch.ones_like(yb)),1)
    sh = np.tile(s[s[:,1]<0],(Nk*Nl*Nw,1))
    sh = torch.FloatTensor(sh).repeat(4,1).to(device)
    yh = torch.cat((yh0,yh1),0).repeat(2,1).to(device)

    # phonon quantities for corner points (v*tau, wave number, and polarization)
    vt0w = torch.FloatTensor(vt0).repeat(1,int(Ns/2)).reshape(-1,1).repeat(2*Nw*Nl,1)
    vt1w = torch.FloatTensor(vt1).repeat(1,int(Ns/2)).reshape(-1,1).repeat(2*Nw*Nl,1)
    kw = torch.FloatTensor(k).repeat(1,int(Ns/2)).reshape(-1,1).repeat(4*Nw*Nl,1).to(device)
    vtw = torch.cat((vt0w,vt1w),0).to(device)
    pw = torch.cat((torch.zeros_like(vt0w),torch.ones_like(vt1w)),0).to(device)
    Lw = torch.FloatTensor(logL).repeat(1,int(Ns/2)*Nk).reshape(-1,1).repeat(Nw*4,1).to(device)

    # input to the network for these corner points
    c0_in = torch.cat((xc,sc,kw,vtw,Lw,pw),1)
    c1_in = torch.cat((xc,Lw),1)
    h0_in = torch.cat((yh,sh,kw,vtw,Lw,pw),1)
    h1_in = torch.cat((yh,Lw),1)

    vt0 = torch.FloatTensor(vt0).repeat(1,Ns).reshape(-1,1).repeat(batchsize[0],1).to(device)
    vt1 = torch.FloatTensor(vt1).repeat(1,Ns).reshape(-1,1).repeat(batchsize[0],1).to(device)
    k = torch.FloatTensor(k).repeat(1,Ns).reshape(-1,1).repeat(batchsize[0],1).to(device)

    for epoch in range(epochs):
        Loss = []
        for batch_idx, ((x_in,y_in,L_in),(xb,kb,vt0b,vt1b,Lb)) in enumerate(zip(dataloader1,dataloader2)):
            net0.zero_grad()
            net1.zero_grad()
            loss1,loss2,loss3,loss4,loss5,loss6,loss7,loss8 = criterion(x_in,y_in,L_in,xb,kb,vt0b,vt1b,Lb)
            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8
            loss.backward()
            optimizer0.step()
            optimizer1.step()
            Loss.append(loss.item())
            Loss_list.append([loss1.item(),loss2.item(),loss3.item(),loss4.item(),loss5.item(),loss6.item(),loss7.item(),loss8.item()])
        if epoch%200 == 0:
            print('Train Epoch: {} Loss: {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'.format(epoch,loss1.item(),loss5.item(),loss6.item(),loss7.item(),loss8.item()))
            torch.save(net0.state_dict(),path+"train_ng_epoch"+str(epoch)+"e.pt")
        Loss = np.array(Loss)
        if np.mean(Loss) < Loss_min:
            torch.save(net0.state_dict(),path+"model0.pt")
            torch.save(net1.state_dict(),path+"model1.pt")
            Loss_min = np.mean(Loss)

    toc = time.time()
    elapseTime = toc - tic
    print("elapse time in parallel = ", elapseTime)
    np.savetxt(path+'Loss.txt',np.array(Loss_list), fmt='%.6f')

def bte_test(x,y,mu,eta,w,k,vt0,vt1,Nx,Ns,Nk,Np,L,Tr,dT,index,path,device):
    net0 = Net(8, 8, 30).to(device)
    net1 = Net(3, 8, 30).to(device)

    net0.load_state_dict(torch.load(path+"model0.pt",map_location=device))
    net0.eval()

    net1.load_state_dict(torch.load(path+"model1.pt",map_location=device))
    net1.eval()

    ########################################
    p = np.vstack((np.zeros((Nk,1)),np.zeros((Nk,1)),np.ones((Nk,1))))
    v,tau,C = param(np.tile(k,(Np,1)),p,Tr)
    v = torch.FloatTensor(v).to(device)
    tau = torch.FloatTensor(tau).to(device)
    C = torch.FloatTensor(C).to(device)

    mu = torch.FloatTensor(mu).repeat(Nx*Nk,1).to(device)
    eta = torch.FloatTensor(eta).repeat(Nx*Nk,1).to(device)
    k = torch.FloatTensor(k/(np.pi*2/a)).repeat(1,Ns).reshape(-1,1).repeat(Nx,1).to(device)
    vt0 = torch.FloatTensor(vt0).repeat(1,Ns).reshape(-1,1).repeat(Nx,1).to(device)
    vt1 = torch.FloatTensor(vt1).repeat(1,Ns).reshape(-1,1).repeat(Nx,1).to(device)
    w = torch.FloatTensor(w).to(device)
    wk = np.pi*2/a/Nk

    deltaT = np.zeros((Nx**2,len(L)))
    tic = time.time()
    for j in range(len(L)):
        for i in range(Nx):
            x1 = torch.FloatTensor(x[i*Nx:(i+1)*Nx]).repeat(1,Ns*Nk).reshape(-1,1).to(device)
            y1 = torch.FloatTensor(y[i*Nx:(i+1)*Nx]).repeat(1,Ns*Nk).reshape(-1,1).to(device)
            L1 = torch.FloatTensor(L[j]).repeat(Ns*Nk*Nx,1).to(device)

            eEq = net1(torch.cat((x1,y1,L1),1))*dT
            e0_in = torch.cat((x1,y1,mu,eta,k,vt0,L1,torch.zeros_like(x1)),1)
            e1_in = torch.cat((x1,y1,mu,eta,k,vt1,L1,torch.ones_like(x1)),1)
            e0 = net0(e0_in)*dT*(10**vt0)/(10**L1) + eEq
            e1 = net0(e1_in)*dT*(10**vt1)/(10**L1) + eEq
            e = torch.cat((e0.reshape(-1,Ns*Nk),e0.reshape(-1,Ns*Nk),e1.reshape(-1,Ns*Nk)),1).reshape(-1,1)

            sum_e = torch.matmul(e.reshape(-1,Ns), w).reshape(-1,1)
            T = torch.matmul(sum_e.reshape(-1,Nk*Np),C*wk/tau*v/(4*np.pi)).reshape(-1,1)/torch.sum(C/tau*wk*v)
            deltaT[i*Nx:(i+1)*Nx,j] = np.squeeze(T.cpu().data.numpy())

    np.savez(str(int(index))+'Square',x = x,y = y,T = (deltaT+dT)/(2*dT),L = L)
    toc = time.time()
    elapseTime = toc - tic
    print("elapse time = ", elapseTime)
```

Main Program

```bash
python main.py
```

and training can begin. If the simulated system changes, all the nets need to be retrained.

```python
# File: main.py
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
import time
import array
from pathlib import Path
import sys
import os
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

from bte_train import bte_train, bte_test
from mesh_2d import *

# Please note that the current model is not perfect for such a problem with harsh boundary conditions (step function at the top corners)
# Training would be better if the temperature distributions are continuous at the boundaries
epochs = 5000
path = "./"
Nl = 4
Nw = 10  # Number of points near the top corner
logL = np.linspace(0, 3, Nl).reshape(-1,1)
batchsize = array.array('i', [120, 80])
batchnum = 15

############################################
Nx = 450
Nb = 30
Nk = 10
N1 = N2 = 12  # number of quadrature points
Np = 3
Ns = N1*N2

Tr = 300
dT = 0.5

x,y,mu,eta,w,k = TwoD_train_mesh(Nx,200,N1,N2,Nk)  # nonuniform spatial mesh
x = np.tile(x,(1,Nl)).reshape(-1,1)
y = np.tile(y,(1,Nl)).reshape(-1,1)
L = np.tile(logL,(Nx,1))

# Since the boundary condition is discontinuous near the top corner (i.e., Th = 1, Tc = -1),
# we have divided the boundary points into two parts, one portion near the top corner, the other away from it
xb = np.linspace(0,1,Nb+2)[1:Nb+1].reshape(-1,1)
yb = np.linspace(0,1,Nw+2)[1:Nw+1].reshape(-1,1)

xb,kb,Lb = np.meshgrid(xb,k,logL)
xb = xb.reshape(-1,1)
kb = kb.reshape(-1,1)
Lb = Lb.reshape(-1,1)

vt0,vt1 = TwoD_vt(k,Tr)
vt0b,vt1b = TwoD_vt(kb,Tr)
kb = kb/(np.pi*2/a)

#===============================================================
#=== model training
#===============================================================

learning_rate = 1e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

bte_train(x,y,mu,eta,L,w,k,vt0,vt1,xb,yb,kb,vt0b,vt1b,Lb,Ns,Nk,Np,Nw,logL,Tr,dT,batchsize,learning_rate,epochs,path,device)

#===============================================================
#=== model testing
#===============================================================

index = 1
Nl = 4
logL = np.linspace(0,3,Nl).reshape(-1,1)
Nx = Ny = 51
N1 = N2 = 24

x,y,mu,eta,w,k = TwoD_test_mesh(Nx,N1,N2,Nk)

bte_test(x,y,mu,eta,w,k,vt0,vt1,Nx,N1*N2,Nk,Np,logL,Tr,dT,index,path,device)

#===============================================================
#=== results plotting
#===============================================================

oldcmp = cm.get_cmap('rainbow', 512)
newcmp = ListedColormap(oldcmp(np.linspace(0.1, 1, 256)))

Data = np.load(str(index)+'Square.npz')
x = Data['x']
y = Data['y']
T = Data['T']
T[T < 0] = 0

T0 = T[:,0].reshape(-1,1)
T1 = T[:,1].reshape(-1,1)
T2 = T[:,2].reshape(-1,1)
T3 = T[:,3].reshape(-1,1)

fig, axs = plt.subplots(2, 2, figsize=(12, 10))
im = axs[0, 0].contourf(x.reshape(Nx,Nx),y.reshape(Nx,Nx),T0.reshape(Nx,Nx),cmap=newcmp,levels=np.linspace(0,1,11))
im = axs[0, 1].contourf(x.reshape(Nx,Nx),y.reshape(Nx,Nx),T2.reshape(Nx,Nx),cmap=newcmp,levels=np.linspace(0,1,11))
im = axs[1, 0].contourf(x.reshape(Nx,Nx),y.reshape(Nx,Nx),T1.reshape(Nx,Nx),cmap=newcmp,levels=np.linspace(0,1,11))
im = axs[1, 1].contourf(x.reshape(Nx,Nx),y.reshape(Nx,Nx),T3.reshape(Nx,Nx),cmap=newcmp,levels=np.linspace(0,1,11))
for ax in axs.flat:
    ax.set_aspect('equal', adjustable='box')
    ax.set(xlabel=r'$X$', ylabel=r'$Y$')
fig.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.2, wspace=0.3)
fig.colorbar(im, ax=axs.ravel().tolist(), shrink=0.6)
plt.savefig('T_square.png', dpi=400, bbox_inches='tight')
```