%config InlineBackend.figure_format = 'retina'
import numpy as np
import matplotlib.pyplot as plt
import math as mt
from numba.decorators import jit
from numba import f8,u8
import scipy.fftpack as sf
from IPython.display import display, clear_output
なぜかNumbaのかけ方で計算時間が全然違う。
なんでだろう?
下のサンプルプログラムはESのPICコード。
@jit
def acce(vt,nop,dx,nx,dt,nt):
np.seterr(divide='ignore', invalid='ignore')
lx=dx*nx
xx=dx*np.arange(nx)
xp=np.linspace(0,lx-lx/nop,nop)
vp=np.random.normal(0,vt,nop)#randn(nop)
kk=2*np.pi/lx*np.r_[np.arange(nx/2),np.arange(-nx/2,0)]
for it in range(nt):
xp=xp+dt*vp
#
xp[xp>lx]=xp[xp>lx]-lx
xp[xp<0.0 ]=xp[xp<0.0 ]+lx
ds=np.zeros(nx)
for ip in range(nop):
ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
wxp=xp[ip]/dx-ixm; wxm=1-wxp
if ixp>nx-1: ixp=ixp-nx#; print(ixp)
ds[ixm]=ds[ixm]+wxm
ds[ixp]=ds[ixp]+wxp
ds=ds/nop*nx
exfft=1j/kk*sf.fft(ds)
exfft[0]=0
ex=np.real(sf.ifft(exfft))
#
for ip in range(nop):
ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
wxp=xp[ip]/dx-ixm; wxm=1-wxp
if ixp>nx-1: ixp=ixp-nx#; print(ixp)
vp[ip]=vp[ip]-dt*(wxm*ex[ixm]+wxp*ex[ixp])
#
# plt.subplot(3,1,1); plt.plot(xp,vp,'.')
# plt.subplot(3,1,2); plt.plot(xx,ds,'-k')
# plt.subplot(3,1,3); plt.plot(xx,ex,'-k')
# plt.show()
%timeit acce(1.0,10**5,1.0,2**8,1.0,256)
functionを分けて作って、Numbaをそれぞれにかませる。 速度が爆速になる。
@jit('(f8,u8,f8,u8,f8,u8)')
def acce_numba(vt,nop,dx,nx,dt,nt):
np.seterr(divide='ignore', invalid='ignore')
lx=dx*nx
xx=dx*np.arange(nx)
xp=np.linspace(0,lx-lx/nop,nop)
vp=np.random.normal(0,vt,nop)#randn(nop)
ds=np.zeros(nx)
kk=2*mt.pi/lx*np.r_[np.arange(nx/2),np.arange(-nx/2,0)]
esave=np.zeros((nt,nx))
for it in range(nt):
push(nop,xp,vp,lx,dt)
dens(dx,nx,nop,xp,ds)
exfft=1j/kk*sf.fft(ds/nop*nx)
exfft[0]=0
ex=np.real(sf.ifft(exfft))
acc(dx,nx,nop,xp,vp,ex,dt)
esave[it,:]=ex[:]
@jit('f8[:](u8,f8[:],f8[:],f8,f8)')
def push(nop,xp,vp,lx,dt):
for ip in range(nop):
xp[ip]=xp[ip]+dt*vp[ip]
if xp[ip]>lx: xp[ip]=xp[ip]-lx
if xp[ip]<0: xp[ip]=xp[ip]+lx
return xp
@jit('f8[:](f8,u8,u8,f8[:],f8[:],f8[:],f8)')
def acc(dx,nx,nop,xp,vp,ex,dt):
for ip in range(nop):
ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
wxp=xp[ip]/dx-ixm; wxm=1-wxp
if ixp>nx-1: ixp=ixp-nx#; print(ixp)
vp[ip]=vp[ip]-dt*(wxm*ex[ixm]+wxp*ex[ixp])
return vp
@jit('f8[:](f8,u8,u8,f8[:],f8[:])')
def dens(dx,nx,nop,xp,ds):
ds=np.zeros(nx)
for ip in range(nop):
ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
wxp=xp[ip]/dx-ixm; wxm=1-wxp
if ixp>nx-1: ixp=ixp-nx#; print(ixp)
ds[ixm]=ds[ixm]+wxm
ds[ixp]=ds[ixp]+wxp
return ds
%timeit acce_numba(1.0,10**5,1.0,2**8,1.0,256)
計算時間のパラメータ依存性を調べる。
変えるパラメータは粒子数(nop)、空間グリッド(nx)、空間ステップ(nt)。
import time
npr=50
elapsed_time=np.zeros(npr)
#changing the number of particles
nop=np.linspace(10,10**6,npr).astype(int)
for iop in range(npr):
start = time.time()
acce_numba(1.0,nop[iop],1.0,2**8,1.0,2**8)
elapsed_time[iop] = time.time() - start
#print(iop,elapsed_time[iop])
plt.plot(nop,elapsed_time,'-o');plt.show()
import time
npr=50
elapsed_time=np.zeros(npr)
nt=np.linspace(10,10**4,npr).astype(int)
for ipr in range(npr):
start = time.time()
acce_numba(1.0,10**4,1.0,2**8,1.0,nt[ipr])
elapsed_time[ipr] = time.time() - start
#print(ipr,elapsed_time[ipr])
plt.plot(nt,elapsed_time,'-o');plt.show()
import time
npr=10
elapsed_time=np.zeros(npr)
nx=2**(np.arange(npr)+3)
#print(nx)
for ipr in range(npr):
start = time.time()
acce_numba(1.0,10**4,1.0,nx[ipr],1.0,1024)
elapsed_time[ipr] = time.time() - start
#print(ipr,elapsed_time[ipr])
plt.plot(nx,elapsed_time,'-o');plt.show()