【发布时间】:2015-01-21 19:15:02
【问题描述】:
def _partition_by_context(self, labels, contexts):
# partition the labels by context
assert len(labels) == len(contexts)
by_context = collections.defaultdict(list)
for i, label in enumerate(labels):
by_context[contexts[i]].append(label)
# now remove any that don't have enough samples
keys_to_remove = []
for key, value in by_context.iteritems():
if len(value) < self._min_samples_context:
keys_to_remove.append(key)
for key in keys_to_remove:
del by_context[key]
return by_context
- labels 是一个 numpy 浮点数组。
- contexts 是一个 Python 元组列表。每个元组的形式为
(unicode, int):例如(u'ffcd6881167b47d492adf3f542af94c6', 2)。上下文值经常重复。例如,上下文列表中可能有 10000 个值,但只有 100 个不同的值。 -
len(labels) == len(contexts)为真,如第一行所述 - 索引 i 处的标签与索引 i 处的上下文相关联。也就是
labels[i]和contexts[i]“一起去”
这个功能的重点是按上下文值对标签中的值进行分区。最后,如果标签计数太低,则删除字典条目。
因此,如果所有上下文值都相同,则返回值将是具有单个条目的字典,key=context, value=所有标签的列表。
如果有 N 个不同的上下文值,则返回值将有 N 个键(每个上下文一个),每个键的值将是与特定上下文关联的标签列表。列表中标签的顺序并不重要。
这个函数用不同的参数被调用了数百万次。我已经确定这是使用gprof2dot 的瓶颈。大部分成本都在第一个 for 循环中的列表 append() 调用中。
谢谢!
【问题讨论】:
-
这个问题更适合Code Review
-
由于您只对标签的数量感兴趣,因此请尝试使用带有 int 而不是 list 的 defaultdict,并且只需递增而不是 append。
-
@jpkotta:我不仅对每个上下文的标签数量感兴趣。我需要每个上下文的所有标签,所以我可以计算每个上下文标签的均值和标准差。
标签: python performance numpy