Python

【EagerPy】PyTorch, TensorFlow, JAXの違いを吸収して自動微分する

EagerPy

こんにちは。

データサイエンティストの卵 ”ひよこ” です。駆け出しエンジニアの”くるる”です。

くるる

”はやぶさ先生”@Cpp_Learning から“プログラミング”機械学習”など、色んなことを教わりながら、成長中です♪

今回はPyTorch, TensorFlow, JAXの違いを吸収するコードの書き方を紹介します。

実践する前に

最初に以下のコードを実行してね🐣🦉

以降で登場するコードは、事前に上記のコードを実行している前提で話を進めます。

深層学習フレームワークのPyTorchで自動微分

くるる先輩~。以下の関数を微分したいです🐣

PyTorch 使えば、以下のコードで簡単に自動微分できるよ🦉

使い方は以下の通りです。

x = 2.0
f(x) = 16.0
df(x)/dx = tensor(24.)

不明点があれば公式チュートリアルをチェック🦉

スポンサーリンク

NumPy専用の自動微分

さすが”くるる”先輩!でも numpy.ndarray を入力したいです🐣

それなら以下のコードの方が良いね🦉

簡単なフローは以下の通りです。

  1. numpy.ndarraytorch.Tensorに変換
  2. 勾配の計算(微分)
  3. torch.Tensornumpy.ndarrayに変換(元に戻す)

使い方は先ほどと、ほとんど同じです。

x = 2.0
f(x) = 16.0
df(x)/dx = 24.0

ユースケース

NumPy 入力はやめました。代わりに Pytorch, TensorFlow, JAX を入力して微分したいです🐣

くるる
くるる
助けて!はやぶさ先生~
はやぶさ
はやぶさ
任せて!ここから交代します

以上のように、やりたい処理は同じ(今回の場合は自動微分)でも、使用するライブラリやフレームワークに合わせてコードを修正する必要があります。

採用するライブラリが”カッチリ”決まっており、かつ今後も変更予定がないなら、一つの専用関数を作れば良いのですが…

”ひよこ”ちゃん みたいに後出しで修正依頼をするケースもあり、それなら最初から『あらゆる違いを吸収できるコードを書く』というユースケースがあります。

そんなユースケースにフィットするのが EagerPy です。

EagerPyとは

EagerPyとは、以下をゴールに設計されたPythonフレームワークです。

Design goals

  • Native Performance: EagerPy operations get directly translated into the corresponding native operations.
  • Fully Chainable: All functionality is available as methods on the tensor objects and as EagerPy functions.
  • Type Checking: Catch bugs before running your code thanks to EagerPy’s extensive type annotations.

引用元:eagerpy|GitHub

実際に使ってみると EagerPy の良さを実感できます。

実践!EagerPyでPyTorch, TensorFlow, JAXの違いを吸収

最初に以下のコマンドで EagerPy をインストールします。

pip install eagerpy

以降から各ライブラリの違いを吸収するコードを書いていきます。

EagerPyの基本的な使い方

以下のように x を EagerPy でラップすることで、PyTorch, TensorFlow, JAXの違いを吸収し、全ての x に対して同じ API が使えるようになります。

API が充実しており、自動微分もサポートしています。

PyTorch, TensorFlow, JAXの違いを吸収して自動微分する

ep.value_and_grad を使えば関数からの出力と微分の結果を同時に得ることができます。また .raw を使うことで、EagerPy tensor を元の native tensor に戻せます

=============== EagerPy ===============
x = PyTorchTensor(tensor(2.))
f(x) = PyTorchTensor(tensor(16.))
df(x)/dx = PyTorchTensor(tensor(24.))
======= PyTorch, TensorFlow, JAX ========
x = tensor(2.)
f(x) = tensor(16.)
df(x)/dx = tensor(24.)

EagerPy とても便利ですね。

スポンサーリンク

【おまけ】ライブラリ未使用でフォーマット違いを吸収して自動微分

EagerPy 使わないと、ライブラリやフレームワークの違いを吸収できないの?🐣🦉

はやぶさ
はやぶさ
もちろん”ゼロ”から自作すれば可能ですよ

例えば、以下のコードでも入力値のフォーマット違いを吸収して微分できます。

使い方は以下の通りです。

========= PyTorch ==========
x = tensor(2.)
f(x) = tensor(16.)
df(x)/dx = tensor(24.)
======== TensorFlow ========
x = tf.Tensor(2.0, shape=(), dtype=float32)
f(x) = tf.Tensor(16.0, shape=(), dtype=float32)
df(x)/dx = tf.Tensor(24.0, shape=(), dtype=float32)
=========== JAX ===========
x = 2.0
f(x) = 16.0
df(x)/dx = 24.0

上記のような簡単な処理で良いなら、自作しても良いかもしれません。

『どんな処理を自作すべきか?どの範囲までライブラリに頼るか?』については、あらゆる制約を考慮した上で選択すれば良いと思います。

おわりに -PyTorch, TensorFlow, JAXの違いを吸収するコードの書き方-

PyTorch, TensorFlow, JAXの違いを吸収するコードの書き方を紹介しました。

EagerPy を活用するも良し、吸収するコードを自作するも良しです。あらゆる制約を考慮しながら、選択していけば良いです。

自身の知識やスキルを増やすことで、選択肢(武器)を増やすことができます。課題に応じてベストな選択ができるように、日々努力できると良いですね。

くるる
くるる
EagerPyのこと知らなかったから、PyTorchのみで自動微分しようとしてた!

そもそも微分のコードの書き方を知りませんでした。とても勉強になりました🐣

はやぶさ
はやぶさ
”ひよこ”ちゃん達みたいに、本記事が参考になれば嬉しいです。現役エンジニアの”はやぶさ”@Cpp_Learningは頑張っている人を応援します。
Amazonギフト券購入で500ポイント還元中!

Amazonギフト券を3,000円分購入すると500ポイントもらえる!

Amazon商品2,000円分とギフト券を同一カートで購入するだけ。

キャンペーン期間は1/31まで。急げ!

 

ギフト券キャンペーンをチェックする

ギフト券と同じカートでまとめ買いができる商品すべて対象

PICK UP BOOKS

  • 数理モデル入門
    数理モデル
  • Jetoson Nano 超入門
    Jetoson Nano
  • 図解速習DEEP LEARNING
    DEEP LEARNING
  • Pythonによる因果分析
    Python