所以要解决这个问题,它有助于解决一个稍微不同的问题。我们想知道每行中总第 k 个截止点的上/下限。然后我们可以通过,验证下界或以下的事物数量k,并且它们之间只有一个值。
我想出了一个策略,在所有行中同时针对这些边界进行二进制搜索。作为二进制搜索,它“应该”通过O(log(n))。每遍涉及O(m)工作,共O(m log(n))次。我把应该放在引号中,因为我没有证据证明它实际上需要O(log(n)) 通行证。事实上,有可能在一行中过于激进,从其他行中发现所选择的枢轴已关闭,然后不得不退出。但我相信它几乎没有退缩,实际上是O(m log(n))。
策略是跟踪下限、上限和中间的每一行。每次传递,我们都会创建一系列加权范围,从低到中,从中到高,从高到末端,权重是其中事物的数量,值是系列中的最后一个。然后,我们在该数据结构中找到第 k 个值(按权重),并将其用作在每个维度中进行二分搜索的枢轴。
如果枢轴超出从下到上的范围,我们会通过在纠正错误的方向上扩大区间来进行纠正。
当我们有正确的顺序时,我们就有了答案。
有很多边缘情况,所以盯着完整的代码可能会有所帮助。
我还假设每一行的所有元素都是不同的。如果不是,您可能会陷入无限循环。 (解决这意味着更多的边缘情况......)
import random
# This takes (k, [(value1, weight1), (value2, weight2), ...])
def weighted_kth (k, pairs):
# This does quickselect for average O(len(pairs)).
# Median of medians is deterministically the same, but a bit slower
pivot = pairs[int(random.random() * len(pairs))][0]
# Which side of our answer is the pivot on?
weight_under_pivot = 0
pivot_weight = 0
for value, weight in pairs:
if value < pivot:
weight_under_pivot += weight
elif value == pivot:
pivot_weight += weight
if weight_under_pivot + pivot_weight < k:
filtered_pairs = []
for pair in pairs:
if pivot < pair[0]:
filtered_pairs.append(pair)
return weighted_kth (k - weight_under_pivot - pivot_weight, filtered_pairs)
elif k <= weight_under_pivot:
filtered_pairs = []
for pair in pairs:
if pair[0] < pivot:
filtered_pairs.append(pair)
return weighted_kth (k, filtered_pairs)
else:
return pivot
# This takes (k, [[...], [...], ...])
def kth_in_row_sorted_matrix (k, matrix):
# The strategy is to discover the k'th value, and also discover where
# that would be in each row.
#
# For each row we will track what we think the lower and upper bounds
# are on where it is. Those bounds start as the start and end and
# will do a binary search.
#
# In each pass we will break each row into ranges from start to lower,
# lower to mid, mid to upper, and upper to end. Some ranges may be
# empty. We will then create a weighted list of ranges with the weight
# being the length, and the value being the end of the list. We find
# where the k'th spot is in that list, and use that approximate value
# to refine each range. (There is a chance that a range is wrong, and
# we will have to deal with that.)
#
# We finish when all of the uppers are above our k, all the lowers
# one are below, and the upper/lower gap is more than 1 only when our
# k'th element is in the middle.
# Our data structure is simply [row, lower, upper, bound] for each row.
data = [[row, 0, min(k, len(row)-1), min(k, len(row)-1)] for row in matrix]
is_search = True
while is_search:
pairs = []
for row, lower, upper, bound in data:
# Literal edge cases
if 0 == upper:
pairs.append((row[upper], 1))
if upper < bound:
pairs.append((row[bound], bound - upper))
elif lower == bound:
pairs.append((row[lower], lower + 1))
elif lower + 1 == upper: # No mid.
pairs.append((row[lower], lower + 1))
pairs.append((row[upper], 1))
if upper < bound:
pairs.append((row[bound], bound - upper))
else:
mid = (upper + lower) // 2
pairs.append((row[lower], lower + 1))
pairs.append((row[mid], mid - lower))
pairs.append((row[upper], upper - mid))
if upper < bound:
pairs.append((row[bound], bound - upper))
pivot = weighted_kth(k, pairs)
# Now that we have our pivot, we try to adjust our parameters.
# If any adjusts we continue our search.
is_search = False
new_data = []
for row, lower, upper, bound in data:
# First cases where our bounds weren't bounds for our pivot.
# We rebase the interval and either double the range.
# - double the size of the range
# - go halfway to the edge
if 0 < lower and pivot <= row[lower]:
is_search = True
if pivot == row[lower]:
new_data.append((row, lower-1, min(lower+1, bound), bound))
elif upper <= lower:
new_data.append((row, lower-1, lower, bound))
else:
new_data.append((row, max(lower // 2, lower - 2*(upper - lower)), lower, bound))
elif upper < bound and row[upper] <= pivot:
is_search = True
if pivot == row[upper]:
new_data.append((row, upper-1, upper+1, bound))
elif lower < upper:
new_data.append((row, upper, min((upper+bound+1)//2, upper + 2*(upper - lower)), bound))
else:
new_data.append((row, upper, upper+1, bound))
elif lower + 1 < upper:
if upper == lower+2 and pivot == row[lower+1]:
new_data.append((row, lower, upper, bound)) # Looks like we found the pivot.
else:
# We will split this interval.
is_search = True
mid = (upper + lower) // 2
if row[mid] < pivot:
new_data.append((row, mid, upper, bound))
elif pivot < row[mid] pivot:
new_data.append((row, lower, mid, bound))
else:
# We center our interval on the pivot
new_data.append((row, (lower+mid)//2, (mid+upper+1)//2, bound))
else:
# We look like we found where the pivot would be in this row.
new_data.append((row, lower, upper, bound))
data = new_data # And set up the next search
return pivot