こんにちは。
データサイエンティストの卵 ”ひよこ” です。駆け出しエンジニアの”くるる”です。
”はやぶさ先生”@Cpp_Learning から“プログラミング”や“機械学習”など、色んなことを教わりながら、成長中です♪
今回はPyTorch, TensorFlow, JAXの違いを吸収するコードの書き方を紹介します。
Contents
実践する前に
最初に以下のコードを実行してね🐣🦉
1 2 3 4 |
import torch import tensorflow as tf import jax.numpy as jnp import numpy as np |
以降で登場するコードは、事前に上記のコードを実行している前提で話を進めます。
深層学習フレームワークのPyTorchで自動微分
くるる先輩~。以下の関数を微分したいです🐣
1 2 3 |
def my_func(x): ''' f(x) = 2x^3 ''' return 2 * x**3 |
PyTorch 使えば、以下のコードで簡単に自動微分できるよ🦉
1 2 3 4 5 6 7 8 9 10 11 12 |
def dx_func(x): ''' Automatic Differentiation with PyTorch ''' # xをtorch.tensorに変換して、微分対象に指定 x = torch.tensor(x, requires_grad=True) # 関数:y = 2x^3 y = my_func(x) # 勾配を計算 y.backward() return x.grad |
使い方は以下の通りです。
1 2 3 4 |
x = 2.0 print("x =", x) print("f(x) =", my_func(x)) print("df(x)/dx =", dx_func(x)) |
x = 2.0
f(x) = 16.0
df(x)/dx = tensor(24.)
不明点があれば公式チュートリアルをチェック🦉
NumPy専用の自動微分
さすが”くるる”先輩!でも numpy.ndarray を入力したいです🐣
それなら以下のコードの方が良いね🦉
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
def dx_func_numpy(x): ''' Automatic Differentiation for Numpy ''' # numpy to torch x = torch.from_numpy(x) # xを微分対象に指定 x.requires_grad_(True) # 関数:y = 2x^3 y = my_func(x) # 勾配を計算 y.backward() return x.grad.to('cpu').detach().numpy().copy() |
簡単なフローは以下の通りです。
- numpy.ndarrayをtorch.Tensorに変換
- 勾配の計算(微分)
- torch.Tensorをnumpy.ndarrayに変換(元に戻す)
使い方は先ほどと、ほとんど同じです。
1 2 3 4 |
x = np.array(2.0) print("x =", x) print("f(x) =", my_func(x)) print("df(x)/dx =", dx_func_numpy(x)) |
x = 2.0
f(x) = 16.0
df(x)/dx = 24.0
ユースケース
NumPy 入力はやめました。代わりに Pytorch, TensorFlow, JAX を入力して微分したいです🐣
以上のように、やりたい処理は同じ(今回の場合は自動微分)でも、使用するライブラリやフレームワークに合わせてコードを修正する必要があります。
採用するライブラリが”カッチリ”決まっており、かつ今後も変更予定がないなら、一つの専用関数を作れば良いのですが…
”ひよこ”ちゃん みたいに後出しで修正依頼をするケースもあり、それなら最初から『あらゆる違いを吸収できるコードを書く』というユースケースがあります。
そんなユースケースにフィットするのが EagerPy です。
EagerPyとは
EagerPyとは、以下をゴールに設計されたPythonフレームワークです。
Design goals
- Native Performance: EagerPy operations get directly translated into the corresponding native operations.
- Fully Chainable: All functionality is available as methods on the tensor objects and as EagerPy functions.
- Type Checking: Catch bugs before running your code thanks to EagerPy’s extensive type annotations.
引用元:eagerpy|GitHub
実際に使ってみると EagerPy の良さを実感できます。
実践!EagerPyでPyTorch, TensorFlow, JAXの違いを吸収
最初に以下のコマンドで EagerPy をインストールします。
pip install eagerpy
以降から各ライブラリの違いを吸収するコードを書いていきます。
EagerPyの基本的な使い方
以下のように x を EagerPy でラップすることで、PyTorch, TensorFlow, JAXの違いを吸収し、全ての x に対して同じ API が使えるようになります。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch import tensorflow as tf import jax.numpy as jnp import numpy as np import eagerpy as ep # No matter which framwork you use, you can use the same code x = torch.tensor([1., 2., 3., 4., 5., 6.]) x = tf.constant([1., 2., 3., 4., 5., 6.]) x = jnp.array([1., 2., 3., 4., 5., 6.]) x = np.array([1., 2., 3., 4., 5., 6.]) # Just wrap a native tensor using EagerPy x = ep.astensor(x) # All of EagerPy's functionality is available as methods x = x.reshape((2, 3)) |
API が充実しており、自動微分もサポートしています。
PyTorch, TensorFlow, JAXの違いを吸収して自動微分する
ep.value_and_grad を使えば関数からの出力と微分の結果を同時に得ることができます。また .raw を使うことで、EagerPy tensor を元の native tensor に戻せます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
# No matter which framwork you use, you can use the same code x = jnp.array(2.0) x = tf.constant(2.0) x = torch.tensor(2.0) # torch to eagerpy xe = ep.astensor(x) # Automatic Differentiation with EagerPy output, grad = ep.value_and_grad(my_func, xe) print("================ EagerPy ================") print("x =", xe) print("f(x) =", output) print("df(x)/dx =", grad) # eagerpy to PyTorch, TensorFlow or JAX print("======= PyTorch, TensorFlow, JAX ========") print("x =", xe.raw) print("f(x) =", output.raw) print("df(x)/dx =", grad.raw) |
=============== EagerPy ===============
x = PyTorchTensor(tensor(2.))
f(x) = PyTorchTensor(tensor(16.))
df(x)/dx = PyTorchTensor(tensor(24.))
======= PyTorch, TensorFlow, JAX ========
x = tensor(2.)
f(x) = tensor(16.)
df(x)/dx = tensor(24.)
EagerPy とても便利ですね。
【おまけ】ライブラリ未使用でフォーマット違いを吸収して自動微分
EagerPy 使わないと、ライブラリやフレームワークの違いを吸収できないの?🐣🦉
例えば、以下のコードでも入力値のフォーマット違いを吸収して微分できます。
1 2 3 4 5 6 7 8 9 10 11 |
def dx_func_all(x): ''' dy/dx ''' delta = 1e-5 x1 = x x2 = x + delta # 関数:y = 2x^3 y1 = my_func(x1) y2 = my_func(x2) return (y2-y1)/(x2-x1) |
使い方は以下の通りです。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
x = torch.tensor(2.0) print("========= PyTorch ==========") print("x =", x) print("f(x) =", my_func(x)) print("df(x)/dx =", dx_func_all(x)) x = tf.constant(2.0) print("======== TensorFlow ========") print("x =", x) print("f(x) =", my_func(x)) print("df(x)/dx =", dx_func_all(x)) x = jnp.array(2.0) print("=========== JAX ============") print("x =", x) print("f(x) =", my_func(x)) print("df(x)/dx =", dx_func_all(x)) |
========= PyTorch ==========
x = tensor(2.)
f(x) = tensor(16.)
df(x)/dx = tensor(24.)
======== TensorFlow ========
x = tf.Tensor(2.0, shape=(), dtype=float32)
f(x) = tf.Tensor(16.0, shape=(), dtype=float32)
df(x)/dx = tf.Tensor(24.0, shape=(), dtype=float32)
=========== JAX ===========
x = 2.0
f(x) = 16.0
df(x)/dx = 24.0
上記のような簡単な処理で良いなら、自作しても良いかもしれません。
『どんな処理を自作すべきか?どの範囲までライブラリに頼るか?』については、あらゆる制約を考慮した上で選択すれば良いと思います。
おわりに -PyTorch, TensorFlow, JAXの違いを吸収するコードの書き方-
PyTorch, TensorFlow, JAXの違いを吸収するコードの書き方を紹介しました。
EagerPy を活用するも良し、吸収するコードを自作するも良しです。あらゆる制約を考慮しながら、選択していけば良いです。
自身の知識やスキルを増やすことで、選択肢(武器)を増やすことができます。課題に応じてベストな選択ができるように、日々努力できると良いですね。
そもそも微分のコードの書き方を知りませんでした。とても勉強になりました🐣