Wasserstein GAN に出てくる最適輸送問題を SciPy で解いてみます。
問題の説明は前回の記事をご覧ください。
問題設定
2つの生産地 A1, A2 から3つの消費地 B1, B2, B3 へ物資を輸送する問題を考えます。生産量 $\mathbf{a}$, 消費量 $\mathbf{b}$, 単位量あたりの輸送コスト $\mathbf{C}$ を次の値に設定します。
a = [ 0.5, 0.5 ]
b = [ 0.25, 0.5, 0.25 ]
B1 B2 B3
C = [ 1, 2, 3 ] A1
[ 3, 2, 1 ] A2
最適輸送問題では輸送 $\mathbf{P}$ と輸送コスト $\mathbf{C}$ が行列の形になっていますが、線型計画問題のライブラリを使う際はこれらを一次元に並べなおす必要があります。
実装
scipy.optimize.linprog
を使うと、$A_{\mathrm{eq}}\mathbf{x} = \mathbf{b}_{\mathrm{eq}}$ の条件を満たしつつ $\langle \mathbf{c}, \mathbf{x} \rangle$ を最小化する解を計算することができます。 bounds
を指定しない場合、$\mathbf{x}$ の各要素は 0 以上と仮定されます。
import numpy
import scipy.optimize
c = numpy.array([1, 2, 3, 3, 2, 1])
A = numpy.array([
[ 1, 1, 1, 0, 0, 0 ],
[ 0, 0, 0, 1, 1, 1 ],
[ 1, 0, 0, 1, 0, 0 ],
[ 0, 1, 0, 0, 1, 0 ],
[ 0, 0, 1, 0, 0, 1 ],
])
b = numpy.array([.5, .5, .25, .5, .25]) # a_1, a_2, b_1, b_2, b_3
print(scipy.optimize.linprog(c, A_eq=A, b_eq=b))
このプログラムを実行すると次の解が得られます。
$ python3 sample.py
fun: 1.5
message: 'Optimization terminated successfully.'
nit: 8
slack: array([], dtype=float64)
status: 0
success: True
x: array([0.25, 0.25, 0. , 0. , 0.25, 0.25])
したがって、最適な輸送 $\mathbf{P}^* $ は
B1 B2 B3
P = [ 0.25, 0.25, 0 ] A1
[ 0 , 0.25, 0.25 ] A2
また、この時の輸送コストは 1.5
となることが分かりました。
参考資料
-
Wasserstein 距離について
https://yokaze.github.io/2019/07/12/ -
scipy.optimize.linprog — SciPy v1.3.0 Reference Guide
https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linprog.html -
Wasserstein GAN and the Kantorovich-Rubinstein Duality
https://vincentherrmann.github.io/blog/wasserstein/ -
Gabriel Peyré and Marco Cuturi, “Computational Optimal Transport”, Foundations and Trends in Machine Learning, vol. 11, no. 5–6, pp. 355–607, 2019.
https://arxiv.org/abs/1803.00567 -
Martin Arjovsky, Soumith Chintala, and Léon Bottou, “Wasserstein Generative Adversarial Networks”, in Proc. ICML, 2017, pp. 214–223.
http://proceedings.mlr.press/v70/arjovsky17a.html