From aff03a7cba3203adc62d2718ab59af794536687b Mon Sep 17 00:00:00 2001 From: Boudjema Ali Date: Wed, 10 Jun 2026 15:29:29 +0200 Subject: [PATCH 1/7] twd --- ot/lp/__init__.py | 7 +++ ot/lp/solver_tree.py | 103 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 ot/lp/solver_tree.py diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c9fa676c4..c42a0fefd 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -36,6 +36,11 @@ linear_circular_ot, ) +from .solver_tree import ( + topological_sort, + tree_wasserstein, +) + __all__ = [ "emd", "emd2", @@ -60,4 +65,6 @@ "free_support_barycenter_generic_costs", "NorthWestMMGluing", "ot_barycenter_energy", + "topological_sort", + "tree_wasserstein", ] diff --git a/ot/lp/solver_tree.py b/ot/lp/solver_tree.py new file mode 100644 index 000000000..243a37b01 --- /dev/null +++ b/ot/lp/solver_tree.py @@ -0,0 +1,103 @@ +from ..backend import get_backend +import numpy as np +from collections import deque + +""" +Solver for the tree wasserstein distance problem +""" + +# Author : Ali Boudjema + + +def topological_sort(tree): + r""" + Computes a topological sort of the given tree + The tree is an array : tree[i] is the direct ancestor of node i, and tree[root] = root + + Je ne vérifie pas que l'arbre est du bon format + Pas de backend ici ? + """ + + n = len(tree) + + in_degree = np.zeros(n) + + for cur_node in range(n): + if cur_node != tree[cur_node]: + in_degree[tree[cur_node]] += 1 + + queue = deque() + + for cur_node in range(n): + if in_degree[cur_node] == 0: + queue.append(cur_node) + + topo_order = [] + + while queue: + cur_node = queue.popleft() + topo_order.append(cur_node) + + if cur_node != tree[cur_node]: + in_degree[tree[cur_node]] -= 1 + + if in_degree[tree[cur_node]] == 0: + queue.append(tree[cur_node]) + + return np.array(topo_order) + + +def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): + r""" + Computes the tree wasserstein distance for a given tree between two empirical distributions + + Parameters + ---------- + tree : array_like, shape(n) + ancestor of each node in the tree (ancestor of root is root) + length : array_like, shape(n) + length of the arc above each node (length of root is 0) + u_weights : array_like, shape(n, ...) + weights of the first empirical distributions + v_weights : array_like, shape(n, ...) + weights of the second empirical distributions + topo_order : array_like, shape(n) + topological order of the tree, optional + + Returns + ------- + cost : float/array_like, shape(...) + The tree wasserstein distance + """ + + n = len(tree) + + assert ( + n == len(length) == u_weights.shape[0] == v_weights.shape[0] + ), "dimension error in the input" + + if topo_order is None: + topo_order = topological_sort(tree) + + nx = get_backend(length, u_weights, v_weights) + + u_cumweights = nx.copy(u_weights) + v_cumweights = nx.copy(v_weights) + + cost = 0 + + for i in range(n): + cur_node = topo_order[i] + + cost += length[cur_node] * nx.abs( + u_cumweights[cur_node] - v_cumweights[cur_node] + ) + + u_cumweights[tree[cur_node]] = ( + u_cumweights[cur_node] + u_cumweights[tree[cur_node]] + ) + v_cumweights[tree[cur_node]] = ( + v_cumweights[cur_node] + v_cumweights[tree[cur_node]] + ) + + return cost From 196736a5aab234d522fe65243fece67a946d5453 Mon Sep 17 00:00:00 2001 From: Boudjema Ali Date: Thu, 11 Jun 2026 09:56:56 +0200 Subject: [PATCH 2/7] twd --- ot/lp/solver_tree.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ot/lp/solver_tree.py b/ot/lp/solver_tree.py index 243a37b01..781a27480 100644 --- a/ot/lp/solver_tree.py +++ b/ot/lp/solver_tree.py @@ -11,14 +11,15 @@ def topological_sort(tree): r""" - Computes a topological sort of the given tree - The tree is an array : tree[i] is the direct ancestor of node i, and tree[root] = root + Computes a topological order of the given tree - Je ne vérifie pas que l'arbre est du bon format - Pas de backend ici ? + Parameters + ----------- + tree: array_like, shape(n, ...) + ancestor of each node in the tree (ancestor of root is root) """ - n = len(tree) + n = tree.shape[0] in_degree = np.zeros(n) @@ -53,9 +54,9 @@ def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): Parameters ---------- - tree : array_like, shape(n) + tree : array_like, shape(n, ...) ancestor of each node in the tree (ancestor of root is root) - length : array_like, shape(n) + length : array_like, shape(n, ...) length of the arc above each node (length of root is 0) u_weights : array_like, shape(n, ...) weights of the first empirical distributions @@ -70,10 +71,10 @@ def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): The tree wasserstein distance """ - n = len(tree) + n = tree.shape[0] assert ( - n == len(length) == u_weights.shape[0] == v_weights.shape[0] + n == length.shape[0] == u_weights.shape[0] == v_weights.shape[0] ), "dimension error in the input" if topo_order is None: From 761db8f39b0a6c019833f8a42e5270c426124477 Mon Sep 17 00:00:00 2001 From: Boudjema Ali Date: Mon, 15 Jun 2026 15:43:49 +0200 Subject: [PATCH 3/7] twd --- ot/lp/solver_tree.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/ot/lp/solver_tree.py b/ot/lp/solver_tree.py index 781a27480..d9a2e59be 100644 --- a/ot/lp/solver_tree.py +++ b/ot/lp/solver_tree.py @@ -3,7 +3,7 @@ from collections import deque """ -Solver for the tree wasserstein distance problem +Solver for the tree wasserstein distance """ # Author : Ali Boudjema @@ -15,13 +15,13 @@ def topological_sort(tree): Parameters ----------- - tree: array_like, shape(n, ...) + tree: array_like, shape(n) ancestor of each node in the tree (ancestor of root is root) """ n = tree.shape[0] - in_degree = np.zeros(n) + in_degree = np.zeros(n, dtype=int) for cur_node in range(n): if cur_node != tree[cur_node]: @@ -39,11 +39,13 @@ def topological_sort(tree): cur_node = queue.popleft() topo_order.append(cur_node) - if cur_node != tree[cur_node]: - in_degree[tree[cur_node]] -= 1 + ancestor = tree[cur_node] + + if cur_node != ancestor: + in_degree[ancestor] -= 1 - if in_degree[tree[cur_node]] == 0: - queue.append(tree[cur_node]) + if in_degree[ancestor] == 0: + queue.append(ancestor) return np.array(topo_order) @@ -54,20 +56,20 @@ def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): Parameters ---------- - tree : array_like, shape(n, ...) + tree : array_like, shape(n) ancestor of each node in the tree (ancestor of root is root) - length : array_like, shape(n, ...) + length : array_like, shape(n) length of the arc above each node (length of root is 0) - u_weights : array_like, shape(n, ...) + u_weights : array_like, shape(n) weights of the first empirical distributions - v_weights : array_like, shape(n, ...) + v_weights : array_like, shape(n) weights of the second empirical distributions topo_order : array_like, shape(n) topological order of the tree, optional Returns ------- - cost : float/array_like, shape(...) + cost : float The tree wasserstein distance """ @@ -94,11 +96,9 @@ def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): u_cumweights[cur_node] - v_cumweights[cur_node] ) - u_cumweights[tree[cur_node]] = ( - u_cumweights[cur_node] + u_cumweights[tree[cur_node]] - ) - v_cumweights[tree[cur_node]] = ( - v_cumweights[cur_node] + v_cumweights[tree[cur_node]] - ) + ancestor = tree[cur_node] + + u_cumweights[ancestor] = u_cumweights[cur_node] + u_cumweights[ancestor] + v_cumweights[ancestor] = v_cumweights[cur_node] + v_cumweights[ancestor] return cost From e2203cf5f9cb6b4ac5b5889aa4c6da4d228df7f0 Mon Sep 17 00:00:00 2001 From: Boudjema Ali Date: Tue, 16 Jun 2026 15:53:34 +0200 Subject: [PATCH 4/7] plan transport --- ot/lp/solver_tree.py | 100 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 87 insertions(+), 13 deletions(-) diff --git a/ot/lp/solver_tree.py b/ot/lp/solver_tree.py index d9a2e59be..92a470826 100644 --- a/ot/lp/solver_tree.py +++ b/ot/lp/solver_tree.py @@ -50,7 +50,9 @@ def topological_sort(tree): return np.array(topo_order) -def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): +def tree_wasserstein( + tree, length, u_weights, v_weights, topo_order=None, return_plans=False +): r""" Computes the tree wasserstein distance for a given tree between two empirical distributions @@ -64,13 +66,18 @@ def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): weights of the first empirical distributions v_weights : array_like, shape(n) weights of the second empirical distributions - topo_order : array_like, shape(n) - topological order of the tree, optional + topo_order : array_like, shape(n), optional + topological order of the tree + return_plans : bool, optional + if True, returns the optimal transport plan between the + two distributions, default is False Returns ------- cost : float The tree wasserstein distance + plans : coo_matrix, optional + If return_plans is True, returns a coo_matrix containing the plan """ n = tree.shape[0] @@ -84,21 +91,88 @@ def tree_wasserstein(tree, length, u_weights, v_weights, topo_order=None): nx = get_backend(length, u_weights, v_weights) - u_cumweights = nx.copy(u_weights) - v_cumweights = nx.copy(v_weights) + mass_dict = {} + + for cur in range(n): + if u_weights[cur] != v_weights[cur]: + mass_dict[cur] = {cur: u_weights[cur] - v_weights[cur]} + else: + mass_dict[cur] = {} + + source_plan = [] + sink_plan = [] + mass_plan = [] + + virt_size = [len(mass_dict[k]) for k in range(n)] cost = 0 - for i in range(n): + depth = nx.zeros(n) + + for i in range(n - 2, -1, -1): cur_node = topo_order[i] + depth[cur_node] = depth[tree[cur_node]] + length[cur_node] - cost += length[cur_node] * nx.abs( - u_cumweights[cur_node] - v_cumweights[cur_node] - ) + for cur in topo_order: + dict_cur = mass_dict[cur] + p = tree[cur] - ancestor = tree[cur_node] + if cur != p: + dict_p = mass_dict[p] + + if virt_size[cur] > virt_size[p]: + mass_dict[cur], mass_dict[p] = dict_p, dict_cur + dict_cur, dict_p = dict_p, dict_cur + virt_size[cur], virt_size[p] = virt_size[p], virt_size[cur] + + while len(dict_cur) > 0 and len(dict_p) > 0: + node_scur = next(iter(dict_cur)) + amount_scur = dict_cur[node_scur] + + node_sp = next(iter(dict_p)) + amount_sp = dict_p[node_sp] + + if (amount_scur > 0) != (amount_sp > 0): + match_amount = min(abs(amount_scur), abs(amount_sp)) + + source = node_scur if amount_scur > 0 else node_sp + sink = node_sp if amount_scur > 0 else node_scur + + source_plan.append(source) + sink_plan.append(sink) + mass_plan.append(match_amount) + + length_path = depth[source] + depth[sink] - 2 * depth[p] + cost += match_amount * length_path + + if amount_scur > 0: + dict_cur[node_scur] -= match_amount + dict_p[node_sp] += match_amount + else: + dict_cur[node_scur] += match_amount + dict_p[node_sp] -= match_amount + + if dict_cur[node_scur] == 0: + del dict_cur[node_scur] + + if dict_p[node_sp] == 0: + del dict_p[node_sp] + + else: + dict_p[node_scur] = amount_scur + del dict_cur[node_scur] + + if len(dict_p) == 0: + mass_dict[cur], mass_dict[p] = dict_p, dict_cur + dict_cur, dict_p = dict_p, dict_cur + + virt_size[p] += virt_size[cur] - u_cumweights[ancestor] = u_cumweights[cur_node] + u_cumweights[ancestor] - v_cumweights[ancestor] = v_cumweights[cur_node] + v_cumweights[ancestor] + plans = nx.coo_matrix( + mass_plan, source_plan, sink_plan, shape=(n, n), type_as=length + ) - return cost + if return_plans: + return cost, plans + else: + return cost From 361938c11d9bfd07ea3b6a68edd0cb36fc5671c5 Mon Sep 17 00:00:00 2001 From: Boudjema Ali Date: Wed, 17 Jun 2026 15:02:40 +0200 Subject: [PATCH 5/7] barycenter --- ot/lp/__init__.py | 5 +++ ot/lp/solver_tree.py | 5 +++ ot/lp/tree_barycenter.py | 93 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 ot/lp/tree_barycenter.py diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index c42a0fefd..99c9c5cd5 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -41,6 +41,10 @@ tree_wasserstein, ) +from .tree_barycenter import ( + tree_barycenter, +) + __all__ = [ "emd", "emd2", @@ -67,4 +71,5 @@ "ot_barycenter_energy", "topological_sort", "tree_wasserstein", + "tree_barycenter", ] diff --git a/ot/lp/solver_tree.py b/ot/lp/solver_tree.py index 92a470826..498d4778c 100644 --- a/ot/lp/solver_tree.py +++ b/ot/lp/solver_tree.py @@ -78,6 +78,11 @@ def tree_wasserstein( The tree wasserstein distance plans : coo_matrix, optional If return_plans is True, returns a coo_matrix containing the plan + + Reference + --------- + The proof of this algorithm uses the formula (3) in the article + Tree-Sliced Variants of Wasserstein Distances """ n = tree.shape[0] diff --git a/ot/lp/tree_barycenter.py b/ot/lp/tree_barycenter.py new file mode 100644 index 000000000..046559aae --- /dev/null +++ b/ot/lp/tree_barycenter.py @@ -0,0 +1,93 @@ +from ..backend import get_backend + + +def wgm(values, weights): + # Returns the weighted geometric median + + nx = get_backend(values, weights) + + sorted_indices = nx.argsort(values) + + values_sorted = values[sorted_indices] + weights_sorted = weights[sorted_indices] + + cum_weights = nx.cumsum(weights_sorted) + + id = nx.searchsorted(cum_weights, 0.5) + + return values_sorted[id] + + +def get_measure(z, tree, length): + # Retrieves the measure from a vector after the wgm + + n = z.shape[0] + + nx = get_backend(length) + + measure = nx.zeros(n) + + for i in range(n): + p = tree[i] + + if i == p: + measure[i] += 1 + else: + measure[i] += z[i] / length[i] + measure[p] -= z[i] / length[i] + + return measure + + +def tree_barycenter(tree, length, measure, weights): + r""" + Computes the tree wasserstein barycenter for a given tree between multiplie empirical distributions + + Parameters + ---------- + tree : array_like, shape(n) + ancestor of each node in the tree (ancestor of root is root) + length : array_like, shape(n) + length of the arc above each node (length of root is 0) + measure : array_like, shape(m, n) + distributions in the tree + weights : array_like, shape(m) + weight of each distribution + + Returns + ------- + barycenter : array_like, shape(n) + distribution of the barycenter + + Reference + --------- + The code is a direct implementation of the algorithm described in + Tree-Wasserstein Barycenter for Large-Scale Multilevel Clustering and Scalable Bayes + + """ + n_measure = measure.shape[0] + n_node = tree.shape[0] + + assert n_measure == weights.shape[0], "dimension error" + + nx = get_backend(measure, weights, length) + + z_measure = nx.zeros((n_measure, n_node)) + + for cur_node in range(n_node): + p = tree[cur_node] + + for id_mes in range(n_measure): + z_measure[id_mes][cur_node] += measure[id_mes][cur_node] + + if cur_node != p: + z_measure[id_mes][p] += measure[id_mes][cur_node] + + z = nx.zeros(n_node) + + for cur_node in range(n_node): + z_measure[:, cur_node] *= length[cur_node] + + z[cur_node] = wgm(z_measure[:, cur_node], weights) + + return get_measure(z, tree, length) From b736caa0b74aa21e307bb4b27a209e3553928a29 Mon Sep 17 00:00:00 2001 From: Boudjema Ali Date: Fri, 19 Jun 2026 11:12:04 +0200 Subject: [PATCH 6/7] correction bug barycenter --- ot/lp/tree_barycenter.py | 59 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 5 deletions(-) diff --git a/ot/lp/tree_barycenter.py b/ot/lp/tree_barycenter.py index 046559aae..c25ea5542 100644 --- a/ot/lp/tree_barycenter.py +++ b/ot/lp/tree_barycenter.py @@ -1,4 +1,50 @@ from ..backend import get_backend +import numpy as np +from collections import deque + +# Author : Ali Boudjema + + +# A retirer +def topological_sort(tree): + r""" + Computes a topological order of the given tree + + Parameters + ----------- + tree: array_like, shape(n) + ancestor of each node in the tree (ancestor of root is root) + """ + + n = tree.shape[0] + + in_degree = np.zeros(n, dtype=int) + + for cur_node in range(n): + if cur_node != tree[cur_node]: + in_degree[tree[cur_node]] += 1 + + queue = deque() + + for cur_node in range(n): + if in_degree[cur_node] == 0: + queue.append(cur_node) + + topo_order = [] + + while queue: + cur_node = queue.popleft() + topo_order.append(cur_node) + + ancestor = tree[cur_node] + + if cur_node != ancestor: + in_degree[ancestor] -= 1 + + if in_degree[ancestor] == 0: + queue.append(ancestor) + + return np.array(topo_order) def wgm(values, weights): @@ -6,14 +52,14 @@ def wgm(values, weights): nx = get_backend(values, weights) - sorted_indices = nx.argsort(values) + sorted_indices = np.argsort(values, kind="stable") values_sorted = values[sorted_indices] weights_sorted = weights[sorted_indices] cum_weights = nx.cumsum(weights_sorted) - id = nx.searchsorted(cum_weights, 0.5) + id = nx.searchsorted(cum_weights, 0.5 - 1e9) return values_sorted[id] @@ -39,7 +85,7 @@ def get_measure(z, tree, length): return measure -def tree_barycenter(tree, length, measure, weights): +def tree_barycenter(tree, length, measure, weights, topo_order=None): r""" Computes the tree wasserstein barycenter for a given tree between multiplie empirical distributions @@ -74,14 +120,17 @@ def tree_barycenter(tree, length, measure, weights): z_measure = nx.zeros((n_measure, n_node)) - for cur_node in range(n_node): + if topo_order is None: + topo_order = topological_sort(tree) + + for cur_node in topo_order: p = tree[cur_node] for id_mes in range(n_measure): z_measure[id_mes][cur_node] += measure[id_mes][cur_node] if cur_node != p: - z_measure[id_mes][p] += measure[id_mes][cur_node] + z_measure[id_mes][p] += z_measure[id_mes][cur_node] z = nx.zeros(n_node) From f81fa30f0118164c15a18e8d818e53b678abb444 Mon Sep 17 00:00:00 2001 From: Boudjema Ali Date: Fri, 19 Jun 2026 11:37:05 +0200 Subject: [PATCH 7/7] tree barycenter --- ot/lp/tree_barycenter.py | 44 +--------------------------------------- 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/ot/lp/tree_barycenter.py b/ot/lp/tree_barycenter.py index c25ea5542..38f1e49e1 100644 --- a/ot/lp/tree_barycenter.py +++ b/ot/lp/tree_barycenter.py @@ -1,52 +1,10 @@ from ..backend import get_backend import numpy as np -from collections import deque +from .solver_tree import topological_sort # Author : Ali Boudjema -# A retirer -def topological_sort(tree): - r""" - Computes a topological order of the given tree - - Parameters - ----------- - tree: array_like, shape(n) - ancestor of each node in the tree (ancestor of root is root) - """ - - n = tree.shape[0] - - in_degree = np.zeros(n, dtype=int) - - for cur_node in range(n): - if cur_node != tree[cur_node]: - in_degree[tree[cur_node]] += 1 - - queue = deque() - - for cur_node in range(n): - if in_degree[cur_node] == 0: - queue.append(cur_node) - - topo_order = [] - - while queue: - cur_node = queue.popleft() - topo_order.append(cur_node) - - ancestor = tree[cur_node] - - if cur_node != ancestor: - in_degree[ancestor] -= 1 - - if in_degree[ancestor] == 0: - queue.append(ancestor) - - return np.array(topo_order) - - def wgm(values, weights): # Returns the weighted geometric median