Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@
linear_circular_ot,
)

from .solver_tree import (
topological_sort,
tree_wasserstein,
)

from .tree_barycenter import (
tree_barycenter,
)

__all__ = [
"emd",
"emd2",
Expand All @@ -60,4 +69,7 @@
"free_support_barycenter_generic_costs",
"NorthWestMMGluing",
"ot_barycenter_energy",
"topological_sort",
"tree_wasserstein",
"tree_barycenter",
]
183 changes: 183 additions & 0 deletions ot/lp/solver_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from ..backend import get_backend
import numpy as np
from collections import deque

"""
Solver for the tree wasserstein distance
"""

# Author : Ali Boudjema


def topological_sort(tree):
r"""
Computes a topological order of the given tree

Parameters
-----------
tree: array_like, shape(n)
ancestor of each node in the tree (ancestor of root is root)
"""

n = tree.shape[0]

in_degree = np.zeros(n, dtype=int)

for cur_node in range(n):
if cur_node != tree[cur_node]:
in_degree[tree[cur_node]] += 1

queue = deque()

for cur_node in range(n):
if in_degree[cur_node] == 0:
queue.append(cur_node)

topo_order = []

while queue:
cur_node = queue.popleft()
topo_order.append(cur_node)

ancestor = tree[cur_node]

if cur_node != ancestor:
in_degree[ancestor] -= 1

if in_degree[ancestor] == 0:
queue.append(ancestor)

return np.array(topo_order)


def tree_wasserstein(
tree, length, u_weights, v_weights, topo_order=None, return_plans=False
):
r"""
Computes the tree wasserstein distance for a given tree between two empirical distributions

Parameters
----------
tree : array_like, shape(n)
ancestor of each node in the tree (ancestor of root is root)
length : array_like, shape(n)
length of the arc above each node (length of root is 0)
u_weights : array_like, shape(n)
weights of the first empirical distributions
v_weights : array_like, shape(n)
weights of the second empirical distributions
topo_order : array_like, shape(n), optional
topological order of the tree
return_plans : bool, optional
if True, returns the optimal transport plan between the
two distributions, default is False

Returns
-------
cost : float
The tree wasserstein distance
plans : coo_matrix, optional
If return_plans is True, returns a coo_matrix containing the plan

Reference
---------
The proof of this algorithm uses the formula (3) in the article
Tree-Sliced Variants of Wasserstein Distances
"""

n = tree.shape[0]

assert (
n == length.shape[0] == u_weights.shape[0] == v_weights.shape[0]
), "dimension error in the input"

if topo_order is None:
topo_order = topological_sort(tree)

nx = get_backend(length, u_weights, v_weights)

mass_dict = {}

for cur in range(n):
if u_weights[cur] != v_weights[cur]:
mass_dict[cur] = {cur: u_weights[cur] - v_weights[cur]}
else:
mass_dict[cur] = {}

source_plan = []
sink_plan = []
mass_plan = []

virt_size = [len(mass_dict[k]) for k in range(n)]

cost = 0

depth = nx.zeros(n)

for i in range(n - 2, -1, -1):
cur_node = topo_order[i]
depth[cur_node] = depth[tree[cur_node]] + length[cur_node]

for cur in topo_order:
dict_cur = mass_dict[cur]
p = tree[cur]

if cur != p:
dict_p = mass_dict[p]

if virt_size[cur] > virt_size[p]:
mass_dict[cur], mass_dict[p] = dict_p, dict_cur
dict_cur, dict_p = dict_p, dict_cur
virt_size[cur], virt_size[p] = virt_size[p], virt_size[cur]

while len(dict_cur) > 0 and len(dict_p) > 0:
node_scur = next(iter(dict_cur))
amount_scur = dict_cur[node_scur]

node_sp = next(iter(dict_p))
amount_sp = dict_p[node_sp]

if (amount_scur > 0) != (amount_sp > 0):
match_amount = min(abs(amount_scur), abs(amount_sp))

source = node_scur if amount_scur > 0 else node_sp
sink = node_sp if amount_scur > 0 else node_scur

source_plan.append(source)
sink_plan.append(sink)
mass_plan.append(match_amount)

length_path = depth[source] + depth[sink] - 2 * depth[p]
cost += match_amount * length_path

if amount_scur > 0:
dict_cur[node_scur] -= match_amount
dict_p[node_sp] += match_amount
else:
dict_cur[node_scur] += match_amount
dict_p[node_sp] -= match_amount

if dict_cur[node_scur] == 0:
del dict_cur[node_scur]

if dict_p[node_sp] == 0:
del dict_p[node_sp]

else:
dict_p[node_scur] = amount_scur
del dict_cur[node_scur]

if len(dict_p) == 0:
mass_dict[cur], mass_dict[p] = dict_p, dict_cur
dict_cur, dict_p = dict_p, dict_cur

virt_size[p] += virt_size[cur]

plans = nx.coo_matrix(
mass_plan, source_plan, sink_plan, shape=(n, n), type_as=length
)

if return_plans:
return cost, plans
else:
return cost
100 changes: 100 additions & 0 deletions ot/lp/tree_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from ..backend import get_backend
import numpy as np
from .solver_tree import topological_sort

# Author : Ali Boudjema


def wgm(values, weights):
# Returns the weighted geometric median

nx = get_backend(values, weights)

sorted_indices = np.argsort(values, kind="stable")

values_sorted = values[sorted_indices]
weights_sorted = weights[sorted_indices]

cum_weights = nx.cumsum(weights_sorted)

id = nx.searchsorted(cum_weights, 0.5 - 1e9)

return values_sorted[id]


def get_measure(z, tree, length):
# Retrieves the measure from a vector after the wgm

n = z.shape[0]

nx = get_backend(length)

measure = nx.zeros(n)

for i in range(n):
p = tree[i]

if i == p:
measure[i] += 1
else:
measure[i] += z[i] / length[i]
measure[p] -= z[i] / length[i]

return measure


def tree_barycenter(tree, length, measure, weights, topo_order=None):
r"""
Computes the tree wasserstein barycenter for a given tree between multiplie empirical distributions

Parameters
----------
tree : array_like, shape(n)
ancestor of each node in the tree (ancestor of root is root)
length : array_like, shape(n)
length of the arc above each node (length of root is 0)
measure : array_like, shape(m, n)
distributions in the tree
weights : array_like, shape(m)
weight of each distribution

Returns
-------
barycenter : array_like, shape(n)
distribution of the barycenter

Reference
---------
The code is a direct implementation of the algorithm described in
Tree-Wasserstein Barycenter for Large-Scale Multilevel Clustering and Scalable Bayes

"""
n_measure = measure.shape[0]
n_node = tree.shape[0]

assert n_measure == weights.shape[0], "dimension error"

nx = get_backend(measure, weights, length)

z_measure = nx.zeros((n_measure, n_node))

if topo_order is None:
topo_order = topological_sort(tree)

for cur_node in topo_order:
p = tree[cur_node]

for id_mes in range(n_measure):
z_measure[id_mes][cur_node] += measure[id_mes][cur_node]

if cur_node != p:
z_measure[id_mes][p] += z_measure[id_mes][cur_node]

z = nx.zeros(n_node)

for cur_node in range(n_node):
z_measure[:, cur_node] *= length[cur_node]

z[cur_node] = wgm(z_measure[:, cur_node], weights)

return get_measure(z, tree, length)
Loading