こんにちは。
データサイエンティストの卵 ”ひよこ” です。駆け出しエンジニアの”くるる”です。

”はやぶさ先生”@Cpp_Learning から“プログラミング”や“機械学習”など、色んなことを教わりながら、成長中です♪
今回はPyTorch, TensorFlow, JAXの違いを吸収するコードの書き方を紹介します。
Contents
実践する前に
最初に以下のコードを実行してね
以降で登場するコードは、事前に上記のコードを実行している前提で話を進めます。
深層学習フレームワークのPyTorchで自動微分
くるる先輩~。以下の関数を微分したいです
PyTorch 使えば、以下のコードで簡単に自動微分できるよ
使い方は以下の通りです。
x = 2.0
f(x) = 16.0
df(x)/dx = tensor(24.)
不明点があれば公式チュートリアルをチェック
NumPy専用の自動微分
さすが”くるる”先輩!でも numpy.ndarray を入力したいです
それなら以下のコードの方が良いね
簡単なフローは以下の通りです。
- numpy.ndarrayをtorch.Tensorに変換
- 勾配の計算(微分)
- torch.Tensorをnumpy.ndarrayに変換(元に戻す)
使い方は先ほどと、ほとんど同じです。
x = 2.0
f(x) = 16.0
df(x)/dx = 24.0
ユースケース
NumPy 入力はやめました。代わりに Pytorch, TensorFlow, JAX を入力して微分したいです
以上のように、やりたい処理は同じ(今回の場合は自動微分)でも、使用するライブラリやフレームワークに合わせてコードを修正する必要があります。
採用するライブラリが”カッチリ”決まっており、かつ今後も変更予定がないなら、一つの専用関数を作れば良いのですが…
”ひよこ”ちゃん みたいに後出しで修正依頼をするケースもあり、それなら最初から『あらゆる違いを吸収できるコードを書く』というユースケースがあります。
そんなユースケースにフィットするのが EagerPy です。
EagerPyとは
EagerPyとは、以下をゴールに設計されたPythonフレームワークです。
Design goals
- Native Performance: EagerPy operations get directly translated into the corresponding native operations.
- Fully Chainable: All functionality is available as methods on the tensor objects and as EagerPy functions.
- Type Checking: Catch bugs before running your code thanks to EagerPy’s extensive type annotations.
引用元:eagerpy|GitHub
実際に使ってみると EagerPy の良さを実感できます。
実践!EagerPyでPyTorch, TensorFlow, JAXの違いを吸収
最初に以下のコマンドで EagerPy をインストールします。
pip install eagerpy
以降から各ライブラリの違いを吸収するコードを書いていきます。
EagerPyの基本的な使い方
以下のように x を EagerPy でラップすることで、PyTorch, TensorFlow, JAXの違いを吸収し、全ての x に対して同じ API が使えるようになります。
API が充実しており、自動微分もサポートしています。
PyTorch, TensorFlow, JAXの違いを吸収して自動微分する
ep.value_and_grad を使えば関数からの出力と微分の結果を同時に得ることができます。また .raw を使うことで、EagerPy tensor を元の native tensor に戻せます。
=============== EagerPy ===============
x = PyTorchTensor(tensor(2.))
f(x) = PyTorchTensor(tensor(16.))
df(x)/dx = PyTorchTensor(tensor(24.))
======= PyTorch, TensorFlow, JAX ========
x = tensor(2.)
f(x) = tensor(16.)
df(x)/dx = tensor(24.)
EagerPy とても便利ですね。
【おまけ】ライブラリ未使用でフォーマット違いを吸収して自動微分
EagerPy 使わないと、ライブラリやフレームワークの違いを吸収できないの?
例えば、以下のコードでも入力値のフォーマット違いを吸収して微分できます。
使い方は以下の通りです。
========= PyTorch ==========
x = tensor(2.)
f(x) = tensor(16.)
df(x)/dx = tensor(24.)
======== TensorFlow ========
x = tf.Tensor(2.0, shape=(), dtype=float32)
f(x) = tf.Tensor(16.0, shape=(), dtype=float32)
df(x)/dx = tf.Tensor(24.0, shape=(), dtype=float32)
=========== JAX ===========
x = 2.0
f(x) = 16.0
df(x)/dx = 24.0
上記のような簡単な処理で良いなら、自作しても良いかもしれません。
『どんな処理を自作すべきか?どの範囲までライブラリに頼るか?』については、あらゆる制約を考慮した上で選択すれば良いと思います。
おわりに -PyTorch, TensorFlow, JAXの違いを吸収するコードの書き方-
PyTorch, TensorFlow, JAXの違いを吸収するコードの書き方を紹介しました。
EagerPy を活用するも良し、吸収するコードを自作するも良しです。あらゆる制約を考慮しながら、選択していけば良いです。
自身の知識やスキルを増やすことで、選択肢(武器)を増やすことができます。課題に応じてベストな選択ができるように、日々努力できると良いですね。
そもそも微分のコードの書き方を知りませんでした。とても勉強になりました