ここでは Wasserstein 計量を最適輸送コストから計算し、この計量が距離の定義(補足 1)を満たすことを証明します。 また、Wasserstein 計量が元の空間上の確率分布間の距離とみなせることを示します。

本記事は Computational Optimal Transport の命題 2.2 の解説です。 簡単のため 1-Wasserstein の場合のみ証明します。

Monge-Kantorovich の問題(COT 2.3 節)

複数の工場 $A_1 ... A_n$ で生産した量 $a_1 ... a_n$ の物資があり、これを別の場所 $B_1 ... B_m$ で利用することにしたい。 物資は使い切るものとし、$B_j$ には $b_j$ の物資を割り当てる。 したがって、$\sum_i a_i = \sum_j b_j$ が成り立つ。

$A_i$ から $B_j$ へ物資を輸送するにあたって、単位量あたり $C_{ij}$ のコストがかかるものとする。 $A_i$ から $B_j$ へ輸送する物資の量を $P_{ij}$ とすると、輸送にかかるコストは $\langle \mathbf{C}, \mathbf{P} \rangle = \sum_{ij} C_{ij} P_{ij}$ である。 ここで、輸送コストを最小にする輸送方法を最適輸送と呼び、最適な輸送 $\mathbf{P}^* $ を求める問題を最適輸送問題という。 問題設定から、$a_i = \sum_j P_{ij}$, $b_j = \sum_i P_{ij}$ が成り立つ。

以下では、有効な輸送の全体を $\mathbf{U}(\mathbf{a}, \mathbf{b})$, 最適輸送コストを $L_\mathbf{C}(\mathbf{a}, \mathbf{b})$ であらわす。
$$ \mathbf{U}(\mathbf{a}, \mathbf{b}) = \left\{ \mathbf{P} \in \mathbb{R}_+^{n \times m} \middle| \mathbf{P}\mathbf{1}_m = \mathbf{a}, \mathbf{P}^\mathrm{T}\mathbf{1}_n = \mathbf{b} \right\} $$ $$ L_\mathbf{C}(\mathbf{a}, \mathbf{b}) = \min_{\mathbf{P} \in \mathbf{U}(\mathbf{a}, \mathbf{b})} \langle \mathbf{C}, \mathbf{P} \rangle $$

最適輸送 $\mathbf{P}^* $ は線型計画法で計算できる(COT 2–3 章)。

Wasserstein 距離(COT 2.4 節)

次に、物資の生産と消費が同じ場所で行われるケースを考える($[ A_1 ... A_n ] = [ B_1 ... B_n ]$)。 また、地点間の輸送コストは距離の定義を満たすものとする(補足 1)。 輸送コスト行列 $\mathbf{D}$ はこの距離関数を行列の形で書き下したものである($D_{ij} = d(A_i, A_j)$)。

命題 2.2 (COT)

資源量を正規化し、生産量 $\mathbf{a}$ と消費量 $\mathbf{b}$ を単体 $\Sigma_n$ 上に制限する($\mathbf{a}, \mathbf{b} \in \Sigma_n$)。 このとき $\Sigma_n$ 上に定義される $p$-Wasserstein 関数
$$ W_p(\mathbf{a}, \mathbf{b}) = \left[ L_{\mathbf{D}^p}(\mathbf{a}, \mathbf{b}) \right]^{\frac{1}{p}} $$
は $\Sigma_n$ 上の距離関数になる($\Sigma_n$ は要素の和が 1 となるような非負値ベクトル全体の集合である)。

諸注意

輸送コストは有限の地点間に定義されるが、$p$-Wasserstein は単体上の距離になる。 $\Sigma_n$ はカテゴリー分布(補足 2)のパラメータ空間と同等であるため、$p$-Wasserstein 関数は確率分布間の距離とみなすことができる。

Wasserstein GAN (Arjovski et al., 2017) で使われているのは 1-Wasserstein である。 この距離関数は Earth Mover’s Distance(EM 距離)とも呼ばれている。

証明 (p = 1 の場合)

  1. $W(\mathbf{a}, \mathbf{a}) = 0$

$$ \begin{align*} W(\mathbf{a}, \mathbf{a}) & = \min \langle \mathbf{D}, \mathbf{P} \rangle \\ & \leq \langle \mathbf{D}, \mathrm{diag} (\mathbf{a}) \rangle & (\because \mathrm{diag} (\mathbf{a}) \in \mathbf{U}(\mathbf{a}, \mathbf{a}))\\ & = \sum_i D_{ii} a_i \\ & = 0 \end{align*} $$
Wasserstein 距離は最適な輸送方法を元に計算されるため、$\mathbf{U}(\mathbf{a}, \mathbf{b})$ に含まれる任意の輸送 $\mathbf{P}$ について $W(\mathbf{a}, \mathbf{b}) \leq \langle \mathbf{D}, \mathbf{P} \rangle$ が成り立つことに注意する。

  1. $W(\mathbf{a}, \mathbf{b}) > 0$ ($\mathbf{a} \neq \mathbf{b}$)

$\mathbf{a} \neq \mathbf{b}$ であるため、ある $k$ が存在し $a_k > b_k$ となる。 任意の有効な輸送 $\mathbf{P}$ について $P_{kk} \leq b_k$, $\sum_j P_{kj} = a_k$ だから、
$$ \begin{align*} \sum_{j \neq k} P_{kj} & = a_k - P_{kk} \\ & \geq a_k - b_k \end{align*} $$
が成り立つ。さらに $D^* = \min_{j \neq k}D_{kj}$ とすると、
$$ \begin{align*} \langle \mathbf{D}, \mathbf{P} \rangle & \geq \sum_{j \neq k} D_{kj} P_{kj} \\ & \geq D^* \sum_{j \neq k} P_{kj} \\ & \geq D^* (a_k - b_k) > 0 \end{align*} $$
が成立する。 $D^* (a_k - b_k)$ は $\mathbf{P}$ によらない定数なので、$W(\mathbf{a}, \mathbf{b}) > 0$ が成り立つ。

  1. $W(\mathbf{a}, \mathbf{b}) = W(\mathbf{b}, \mathbf{a})$

$W(\mathbf{a}, \mathbf{b})$, $W(\mathbf{b}, \mathbf{a})$ をあたえる最適輸送をそれぞれ $\mathbf{P}^* $, $\mathbf{Q}^* $ とする。このとき
$$ \begin{align*} W(\mathbf{a}, \mathbf{b}) & = \langle \mathbf{D}, \mathbf{P}^* \rangle \\ & \leq \langle \mathbf{D}, (\mathbf{Q}^* )^\mathrm{T} \rangle & (\because (\mathbf{Q}^* )^\mathrm{T} \in \mathbf{U}(\mathbf{a}, \mathbf{b})) \\ & = \langle \mathbf{D}^\mathrm{T}, \mathbf{Q}^* \rangle \\ & = \langle \mathbf{D}, \mathbf{Q}^* \rangle & (\because \mathbf{D} = \mathbf{D}^\mathrm{T}) \\ & = W(\mathbf{b}, \mathbf{a}) \end{align*} $$
逆も同等のため、$W(\mathbf{a}, \mathbf{b}) = W(\mathbf{b}, \mathbf{a})$ が成り立つ。

  1. $W(\mathbf{a}, \mathbf{c}) \leq W(\mathbf{a}, \mathbf{b}) + W(\mathbf{b}, \mathbf{c})$

この部分の証明は Computational Optimal Transport に詳しく書かれているため省きます。

補足 1. 距離の定義

距離関数とは、集合 $S$ 上で定義された実数値二値関数 $d$ で、次の条件を満たすものである。

  1. $d(a, a) = 0$

  2. $d(a, b) > 0$ ($a \neq b$)

  3. $d(a, b) = d(b, a)$

  4. $d(a, c) \leq d(a, b) + d(b, c)$

補足 2. カテゴリー分布

試行回数が 1 の多項分布を特別にカテゴリー分布と呼ぶことがある。
$$ \mathrm{Cat}(\mathbf{x}|\mathbf{p}) = \prod_{i = 1}^n p_i^{x_i} $$
$\mathbf{x}$ は 1 つの要素だけが 1 となり、他の要素がすべて 0 となるようなベクトルである。 これを 1-hot ベクトルという。

補足 3. 対角行列

$\mathrm{diag}$ はベクトルから対角行列を作る関数で、次のように定義される。
$$ \mathrm{diag} (\mathbf{a}) = \left[ \begin{array}{rrr} a_1 & & \\ & \ddots & \\ & & a_n \end{array} \right] $$

参考文献