St_Hakky’s blog

Data Science / Human Resources / Web Applicationについて書きます

【Python】Pythonの高速化のためのNumbaのTips

こんにちは。

最近、仕事でNumbaを使ったコードと対峙しまして、良い機会だと思って、NumbaのTipsをまとめてみました

Numbaとは

PythonやNumpyのコードを高速な機械語に変換するためのJITコンパイラのことを指します。PythonをNumbaが機械語に変換できるように書く必要はありますが、デコレーターを関数につけるだけで実行時に高速な機械語に変換して、実行してくれるようになります。

手軽に高速化ができるため、Pythonで行列の演算などを高速にしたい場合に有効な感じです。

私とかは、よくPandasを使うのですが、Pandasなどで処理をするのが重い時とかに、numpyにしてから、Numbaとかで高速に処理することを試みるみたいな感じのことをやっています。お手軽な感じに高速化できるので、便利です。

Numbaを手っ取り早く学びたい人

どうでもいいからNumbaの使い方を知りたい方は、まずは公式チュートリアルの5minguideを読むのが良きかなと思います。

Install

condaでインストールするには、以下の通り。

$ conda install numba

pipでインストールする場合は以下の通り。

$ pip install numba

そのほかのインストール方法は以下の公式サイトを参照あれ。

Numbaの効果

簡単なコードを使って、Numbaの効果をみてみたいと思います。

コード

今回は、公式のチュートリアルを参考に、具体例をベースにやってみます。

import time
import numpy as np
from numba import jit
from numba import prange


@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def fast(a): # Function is compiled to machine code when called the first time
    trace = 0.0
    for i in prange(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting


def slow(a):
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])
    return a + trace


if __name__ == '__main__':
    x = np.arange(10000).reshape(100, 100)

    # fast with compile
    start_t = time.time()
    fast(x)
    end_t = time.time()
    time_fast_with_compile = end_t - start_t
    print('fast with compile :', time_fast_with_compile)

    # fast without compile
    start_t = time.time()
    fast(x)
    end_t = time.time()
    time_fast_without_compile = end_t - start_t
    
    print('fast without compile :', time_fast_without_compile)

    # slow
    start_t = time.time()
    slow(x)
    end_t = time.time()
    time_slow = end_t - start_t
    
    print('slow : ', time_slow)
    
    print('-' * 10)
    print('(fast with compile) / (slow) : ', time_fast_with_compile / time_slow)
    print('(fast without compile) / (slow) : ', time_fast_without_compile / time_slow)
実行結果

上のコードを実行すると以下のような結果が得られます。

fast with compile : 0.11895394325256348
fast without compile : 3.886222839355469e-05
slow :  0.000209808349609375
----------
(fast with compile) / (slow) :  566.9647727272727
(fast without compile) / (slow) :  0.18522727272727274
結果

Numbaには、冒頭の説明にあった通り、実行時にコンパイルをする必要があります。そのため、コンパイルをする時間も含まれる一番最初の実行は、slowの関数の実行よりも遅くなります。

2回目の実行の際には、cacheが効くため、実行時間が速くなります。コンパイルのオーバーヘッドが問題になる場合には、 @jitのデコレーターのパラメーターに、 cacheというパラメーターに Trueを指定してあげればオーケーです。

公式のドキュメントには、以下のように説明されています。

If true, cache enables a file-based cache to shorten compilation times when the function was already compiled in a previous invocation. The cache is maintained in the __pycache__ subdirectory of the directory containing the source file; if the current user is not allowed to write to it, though, it falls back to a platform-specific user-wide cache directory (such as $HOME/.cache/numba on Unix platforms).

これにより、2回目の実行からはcacheから呼び出すため、速くなります。

Numbaの使い方

nopython

Numbaには、 nopython modeobject modeがあります。nopython modeは高速な機械語にコンパイルすることができた状態なのですが、もしコンパイルできなかった際には、object modeにfall backされます。

@jit(nopython=True)
def f(x, y):
    return x + y
型指定

Numbaを使うときには、型を指定せずに使うと、Numbaの方で勝手に推論してくれるのですが、以下のように指定することもできます。

from numba import jit, int32

@jit(int32(int32, int32))
def f(x, y):
    # A somewhat trivial example
    return x + y

以下のサイトに指定できる型もあります。

cache

コンパイルを何回もするのを避けるために、キャッシュを使うためには、以下のような感じでパラメーターを設定します。

@jit(cache=True)
def f(x, y):
    return x + y
parallel

並列で実行するオプションは、以下のような感じでパラメーターを設定することでできます。

@jit(nopython=True, parallel=True)
def f(x, y):
    return x + y

Numbaの注意点

Pandasなど、全てのPythonをコンパイルして高速化することはできません。できないときには、object modeにfall backされるのですが、そうすると逆に遅くなります。

詳しい対応しているfeatureについては、以下の公式ドキュメントが参考になります*1

そのほかの注意点は、以下の参考の記事にもあるので、興味がある方は要チェック。sumとかが意外に使えなかったりします笑。

より速くしたいとき

以下あたりですかね。

  • object modeになっていないところがないか確認する
  • 型をめんどくさがらずに指定する (こちらの記事によると、型指定よりさらに10倍高速化とのこと。コンパイルのオーバーヘッドも少なくすることができる)
  • fastmath使えるところは使っていく
  • rangeのところを prangeにする+parallel=True使っていく
  • 何回も使う関数は@jit(cache=True)

その他参考としては、Numbaの公式の Performance Tipsが参考になります。

Performance Tips — Numba 0.49.1-py3.6-macosx-10.7-x86_64.egg documentation

*1:覚えてられないので、都度コンパイルしながらでいいと思いますが