2019 年 8 月更新
这是另一个简单的实现,对于中等规模的输入非常快。它返回"各坐标越大越好"意义下的 Pareto 有效点,并假设输入点互不相同。
def keep_efficient(pts):
    """Return the Pareto-efficient rows of ``pts``, where larger is better.

    Row ``y`` dominates row ``x`` when ``y >= x`` in every column and
    ``y > x`` in at least one. Only non-dominated rows are returned, ordered
    by decreasing coordinate sum. Exact duplicate rows collapse to one copy.

    :param pts: (n_points, n_dims) array-like
    :return: (n_efficient, n_dims) array
    """
    pts = np.asarray(pts)
    # Sort by decreasing coordinate sum: a dominating point always has a
    # larger sum, so nothing later in the order can dominate anything earlier.
    # np.lexsort uses its LAST key as primary; the per-column keys break ties
    # lexicographically, which keeps that invariant even when float rounding
    # makes the sums of a dominating pair compare equal.
    order = np.lexsort(np.vstack([pts.T[::-1], pts.sum(axis=1)]))[::-1]
    pts = pts[order]
    # Reusable mask buffer so no fresh boolean array is allocated per round.
    undominated = np.ones(pts.shape[0], dtype=bool)
    i = 0
    while i < pts.shape[0]:
        n = pts.shape[0]
        # Row j > i survives point i only if it is strictly better somewhere;
        # ">=" would wrongly keep rows that merely tie i in one coordinate.
        undominated[i + 1:n] = (pts[i + 1:] > pts[i]).any(axis=1)
        pts = pts[undominated[:n]]
        # Reset the slice just used: once the array is compacted, stale False
        # entries would otherwise delete unrelated rows in the next round.
        undominated[i + 1:n] = True
        i += 1
    return pts
我们首先根据坐标总和对点进行排序。这很有用,因为
- 对于许多数据分布,具有最大坐标和的点将支配大量点。
- 如果点 x 的坐标和大于点 y 的坐标和,则 y 不能支配 x。
下面是与 Peter 的回答(is_pareto_efficient)对比的基准测试结果,测试数据由 np.random.randn 生成。
N=10000 d=2
keep_efficient
1.31 ms ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
is_pareto_efficient
6.51 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
N=10000 d=3
keep_efficient
2.3 ms ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
is_pareto_efficient
16.4 ms ± 156 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
N=10000 d=4
keep_efficient
4.37 ms ± 38.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
is_pareto_efficient
21.1 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
N=10000 d=5
keep_efficient
15.1 ms ± 491 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
is_pareto_efficient
110 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
N=10000 d=6
keep_efficient
40.1 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
is_pareto_efficient
279 ms ± 2.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
N=10000 d=15
keep_efficient
3.92 s ± 125 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
is_pareto_efficient
5.88 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
凸包启发式
我最近又研究了这个问题,发现了一个有用的启发式方法:当点数很多、各点独立分布且维度较低时,它的效果很好。
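这个想法是计算点的凸包。由于维度少且点独立分布,凸包的顶点数会很少。直观地说,我们可以预期凸包的一些顶点支配许多原始点。需要注意的是,一个不受其他凸包顶点支配的凸包顶点,仍可能被凸包内部的点支配。例如在二维中,(0, 0) 是 {(0, 0), (3, -1), (-1, 3)} 的凸包顶点,且不受另外两个顶点支配,却被内部点 (0.9, 0.9) 支配。因此,这些候选点还需要与剩余的点再比较一次。
这给出了一个简单的迭代算法。我们反复
- 计算凸包。
- 找出凸包顶点中互不支配的候选点。
- 过滤点以去除那些由候选点支配的点。
- 保存不被任何剩余点支配的候选点。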
我为维度 3 添加了一些基准。似乎对于某些点分布,这种方法产生了更好的渐近性。
import numpy as np
from scipy import spatial
from functools import reduce
# benchmark data: 10 million points drawn uniformly from the unit cube [0, 1)^3
pts = np.random.random_sample((10_000_000, 3))
def filter_(pts, pt):
    """
    Keep the rows of ``pts`` that the single point ``pt`` does not dominate.

    A row is discarded when it is no larger than ``pt`` in every column and
    strictly smaller in at least one, i.e. larger is better. Rows equal to
    ``pt`` are kept.
    """
    no_better_anywhere = np.all(pts <= pt, axis=-1)
    worse_somewhere = np.any(pts < pt, axis=-1)
    survivors = np.logical_not(no_better_anywhere & worse_somewhere)
    return pts[survivors]
def get_pareto_undominated_by(pts1, pts2=None):
    """
    Return the rows of ``pts1`` that no row of ``pts2`` Pareto-dominates.

    Dominance is decided by ``filter_`` (larger is better in every column).
    When ``pts2`` is omitted, ``pts1`` is filtered against itself; equal rows
    do not dominate each other, so duplicates all survive.
    """
    reference = pts1 if pts2 is None else pts2
    survivors = pts1
    # fold: knock out everything each reference point dominates, in order
    for pt in reference:
        survivors = filter_(survivors, pt)
    return survivors
def get_pareto_frontier(pts):
    """
    Return the Pareto-undominated rows of ``pts`` (larger is better in every
    column), found with the convex-hull heuristic.

    Each round:
      1. take the vertices of the convex hull of the remaining points;
      2. keep, as candidates, the hull vertices no other hull vertex dominates;
      3. discard every remaining point that a candidate dominates;
      4. accept only the candidates that no surviving point dominates.

    Step 4 is required: a hull vertex undominated among hull vertices can
    still be dominated by a point strictly inside the hull, e.g. (0, 0) in the
    hull of {(0, 0), (3, -1), (-1, 3)} is dominated by (0.9, 0.9).

    :param pts: (n_points, n_dims) array
    :return: (n_pareto, n_dims) array, rows grouped by the round that found them
    """
    pareto_groups = []
    # loop while there are points remaining
    while pts.shape[0]:
        # brute force if there are few points
        if pts.shape[0] < 10:
            pareto_groups.append(get_pareto_undominated_by(pts))
            break
        try:
            # compute vertices of the convex hull
            hull_vertices = spatial.ConvexHull(pts).vertices
        except spatial.QhullError:
            # degenerate input (e.g. all points on a hyperplane): no usable
            # hull, so finish the remaining points by brute force
            pareto_groups.append(get_pareto_undominated_by(pts))
            break
        hull_pts = pts[hull_vertices]
        # drop the hull vertices from the pool of remaining points
        nonhull_mask = np.ones(pts.shape[0], dtype=bool)
        nonhull_mask[hull_vertices] = False
        pts = pts[nonhull_mask]
        # hull vertices that are undominated among the hull vertices
        candidates = get_pareto_undominated_by(hull_pts)
        # filter remaining points to keep those not dominated by a candidate
        pts = get_pareto_undominated_by(pts, candidates)
        # reject candidates dominated by a surviving (interior) point; points
        # already discarded cannot dominate a candidate by transitivity
        dominated = np.array(
            [((pts >= c).all(axis=1) & (pts > c).any(axis=1)).any()
             for c in candidates],
            dtype=bool,
        )
        pareto_groups.append(candidates[~dominated])
    # np.vstack([]) raises, so an empty input is returned as-is
    return np.vstack(pareto_groups) if pareto_groups else pts
# --------------------------------------------------------------------------------
# previous solutions
# --------------------------------------------------------------------------------
def is_pareto_efficient_dumb(costs):
    """
    Brute-force Pareto efficiency test, where lower cost is better.

    Point j dominates point i when ``costs[j] <= costs[i]`` in every column
    and ``costs[j] < costs[i]`` in at least one. A point is efficient when no
    other point dominates it, so exact duplicates all count as efficient.

    :param costs: An (n_points, n_costs) array
    :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient
    """
    costs = np.asarray(costs)
    is_efficient = np.ones(costs.shape[0], dtype=bool)
    for i, c in enumerate(costs):
        # A point never dominates itself (that needs a strict improvement),
        # so row i need not be excluded from the comparison.
        dominators = (costs <= c).all(axis=1) & (costs < c).any(axis=1)
        is_efficient[i] = not dominators.any()
    return is_efficient
def is_pareto_efficient(costs):
    """
    Find the Pareto-efficient points, where lower cost is better.

    Walks the points in order. Each point still marked efficient eliminates
    every other candidate that is no cheaper than it in any column
    (``costs[j] >= c`` everywhere). Of several identical points, only the
    first is kept.

    :param costs: An (n_points, n_costs) array
    :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient
    """
    costs = np.asarray(costs)
    is_efficient = np.ones(costs.shape[0], dtype=bool)
    for i, c in enumerate(costs):
        if is_efficient[i]:
            # Keep only candidates strictly cheaper than c in some column;
            # "<=" would keep points that merely tie c in one column.
            is_efficient[is_efficient] = np.any(costs[is_efficient] < c, axis=1)
            is_efficient[i] = True  # the strict comparison above drops c itself
    return is_efficient
def dominates(row, rowCandidate):
    """
    Return True when ``row`` weakly dominates ``rowCandidate``: every paired
    entry satisfies ``r >= rc``.

    Entries are paired with ``zip``, so extra trailing entries in the longer
    sequence are ignored. Comparison stops at the first failing pair.
    """
    for mine, theirs in zip(row, rowCandidate):
        if not mine >= theirs:
            return False
    return True
def cull(pts, dominates):
    """
    Split ``pts`` into ``(cleared, dominated)`` lists.

    ``dominates(a, b)`` is the caller's predicate for "a dominates b". The
    first remaining point is taken as a candidate. Every later point it
    dominates moves to ``dominated``. The candidate is then cleared unless one
    of the surviving later points dominates it. The survivors form the next
    pool.
    """
    cleared, dominated = [], []
    pool = pts
    while pool:
        head = pool[0]
        survivors = []
        for point in pool[1:]:
            if dominates(head, point):
                dominated.append(point)
            else:
                survivors.append(point)
        beaten = any(dominates(rival, head) for rival in survivors)
        (dominated if beaten else cleared).append(head)
        pool = survivors
    return cleared, dominated
# --------------------------------------------------------------------------------
# benchmarking
# --------------------------------------------------------------------------------
# Converting the (possibly 10M-row) array into a list of lists costs
# gigabytes of memory and is needed only by the pure-Python baseline,
# so that baseline is opt-in.
RUN_PURE_PYTHON_BASELINE = False

import timeit

if RUN_PURE_PYTHON_BASELINE:
    # to accommodate the original non-numpy solution, which expects lists
    pts_list = [list(pt) for pt in pts]
    print('Old non-numpy solution:\t{}'.format(
        timeit.timeit('cull(pts_list, dominates)', number=3, globals=globals())))
print('Numpy solution:\t{}'.format(
    timeit.timeit('is_pareto_efficient(pts)', number=3, globals=globals())))
print('Convex hull heuristic:\t{}'.format(
    timeit.timeit('get_pareto_frontier(pts)', number=3, globals=globals())))
结果
# >>= python temp.py # 1,000 points
# Old non-numpy solution: 0.0316428339574486
# Numpy solution: 0.005961259012110531
# Convex hull heuristic: 0.012369581032544374
# >>= python temp.py # 1,000,000 points
# Old non-numpy solution: 70.67529802105855
# Numpy solution: 5.398462114972062
# Convex hull heuristic: 1.5286884519737214
# >>= python temp.py # 10,000,000 points
# Numpy solution: 98.03680767398328
# Convex hull heuristic: 10.203076395904645
原帖
我尝试通过一些调整来重写相同的算法。我认为您的大部分性能问题都来自 inputPoints.remove(row):它需要在点列表中线性搜索,而按索引删除会更高效。
我还修改了 dominates 函数,以避免一些冗余比较,这在更高维度上可能会更有帮助。
def dominates(row, rowCandidate):
    """
    Return True when ``row`` weakly dominates ``rowCandidate``: every paired
    entry satisfies ``r >= rc``.

    Entries are paired with ``zip``, so extra trailing entries in the longer
    sequence are ignored. Comparison stops at the first failing pair.
    """
    for mine, theirs in zip(row, rowCandidate):
        if not mine >= theirs:
            return False
    return True
def cull(pts, dominates):
    """
    Split ``pts`` into ``(cleared, dominated)`` lists.

    ``dominates(a, b)`` is the caller's predicate for "a dominates b". The
    first remaining point is taken as a candidate. Every later point it
    dominates moves to ``dominated``. The candidate is then cleared unless one
    of the surviving later points dominates it. The survivors form the next
    pool.
    """
    cleared, dominated = [], []
    pool = pts
    while pool:
        head = pool[0]
        survivors = []
        for point in pool[1:]:
            if dominates(head, point):
                dominated.append(point)
            else:
                survivors.append(point)
        beaten = any(dominates(rival, head) for rival in survivors)
        (dominated if beaten else cleared).append(head)
        pool = survivors
    return cleared, dominated