【问题标题】:SQL to calculate the Tanimoto Coefficient of several vectorsSQL计算几个向量的谷本系数
【发布时间】:2023-03-21 21:48:01
【问题描述】:

我认为用一个例子来解释我的问题会更容易。

我有一张包含食谱成分的表格,并且我实现了一个函数来计算成分之间的Tanimoto coefficient。计算两种成分之间的系数已经足够快了(需要 3 个 sql 查询),但它不能很好地扩展。要计算所有可能成分组合之间的系数,它需要 N + (N*(N-1))/2 次查询或 500500 次查询,只需 1k 种成分。有没有更快的方法来做到这一点?到目前为止,这是我得到的:

class Filtering():
  def __init__(self):
    self._connection=sqlite.connect('database.db')

  def n_recipes(self, ingredient_id):
    cursor = self._connection.cursor()
    cursor.execute('''select count(recipe_id) from recipe_ingredient
        where ingredient_id = ? ''', (ingredient_id, ))
    return cursor.fetchone()[0]

  def n_recipes_intersection(self, ingredient_a, ingredient_b):
    cursor = self._connection.cursor()
    cursor.execute('''select count(drink_id) from recipe_ingredient where
        ingredient_id = ? and recipe_id in (
        select recipe_id from recipe_ingredient
        where ingredient_id = ?) ''', (ingredient_a, ingredient_b))
    return cursor.fetchone()[0]

  def tanimoto(self, ingredient_a, ingredient_b):
    n_a, n_b = map(self.n_recipes, (ingredient_a, ingredient_b))
    n_ab = self.n_recipes_intersection(ingredient_a, ingredient_b)
    return float(n_ab) / (n_a + n_b - n_ab)

【问题讨论】:

  • 真的很好奇你为什么选择使用谷本而不是余弦或其他相似度算法。我正在考虑执行类似的计算,很想听听您的理由。

标签: python sql collaborative-filtering


【解决方案1】:

您为什么不简单地将所有配方提取到内存中,然后在内存中计算 Tanimoto 系数?

它更简单,而且速度更快。

【讨论】:

  • 这是我的第一个想法,但您将如何实现它?循环遍历所有食谱的成分并为找到的每种成分和组合增加计数器?我在数据库中有超过 60k 项,所以即使这样也需要一些时间。
  • 捂脸!事实证明,这种方法比我想象的要快得多。计算所有系数只用了 4 秒。谢谢。
  • 一般来说,这是我的经验。人们写了太多的 SQL。
【解决方案2】:

如果有人感兴趣,这是我在 Alex 和 S.Lotts 的建议后提出的代码。谢谢各位。

def __init__(self):
    self._connection=sqlite.connect('database.db')
    self._counts = None
    self._intersections = {}

def inc_intersections(self, ingredients):
    ingredients.sort()
    lenght = len(ingredients)
    for i in xrange(1, lenght):
        a = ingredients[i]
        for j in xrange(0, i):
            b = ingredients[j]
            if a not in self._intersections:
                self._intersections[a] = {b: 1}
            elif b not in self._intersections[a]:
                self._intersections[a][b] = 1
            else:
                self._intersections[a][b] += 1


def precompute_tanimoto(self):
    counts = {}
    self._intersections = {}

    cursor = self._connection.cursor()
    cursor.execute('''select recipe_id, ingredient_id
        from recipe_ingredient
        order by recipe_id, ingredient_id''')
    rows = cursor.fetchall()            

    print len(rows)

    last_recipe = None
    for recipe, ingredient in rows:
        if recipe != last_recipe:
            if last_recipe != None:
                self.inc_intersections(ingredients)
            last_recipe = recipe
            ingredients = [ingredient]
        else:
            ingredients.append(ingredient)

        if ingredient not in counts:
            counts[ingredient] = 1
        else:
            counts[ingredient] += 1

    self.inc_intersections(ingredients)

    self._counts = counts

def tanimoto(self, ingredient_a, ingredient_b):
    if self._counts == None:
        self.precompute_tanimoto()

    if ingredient_b > ingredient_a:
        ingredient_b, ingredient_a = ingredient_a, ingredient_b

    n_a, n_b = self._counts[ingredient_a], self._counts[ingredient_b]
    n_ab = self._intersections[ingredient_a][ingredient_b]

    print n_a, n_b, n_ab

    return float(n_ab) / (n_a + n_b - n_ab)

【讨论】:

    【解决方案3】:

    如果您有 1000 种成分,则 1000 次查询足以将每种成分映射到内存中的一组食谱。如果(比方说)一种成分通常是大约 100 个食谱的一部分,那么每组将占用几 KB,因此整个字典将只占用几 MB——将整个内容保存在内存中绝对没有问题(而且仍然不严重如果每种成分的平均食谱数量增长了一个数量级,则会出现内存问题)。

    result = dict()
    for ing_id in all_ingredient_ids:
        cursor.execute('''select recipe_id from recipe_ingredient
            where ingredient_id = ?''', (ing_id,))
        result[ing_id] = set(r[0] for r in cursor.fetchall())
    return result
    

    在这 1000 次查询之后,所需的 500,000 次成对 Tanimoto 系数计算中的每一次显然都是在内存中完成的——您可以预先计算各种集合的长度的平方以进一步加速(并将它们停放在另一个dict),每对的关键“A dotproduct B”组件当然是集合的交集的长度。

    【讨论】:

    • 谢谢亚历克斯! +1 你的好建议,但我设法在内存中完成整个计算,一次获取所有数据。整个过程不到 4 秒。
    【解决方案4】:

    我认为这会将您减少到每对交叉点的 2 个选择,以及每对总共 4 个查询。您无法摆脱 O(N^2) ,因为您正在尝试所有对 - N*(N-1)/2 只是有多少对。

    def n_recipes_intersection(self, ingredient_a, ingredient_b):
      cursor = self._cur
      cursor.execute('''
        select count(recipe_id)
          from recipe_ingredient as A 
            join recipe_ingredient as B using (recipe_id)
          where A.ingredient_id = ? 
            and B.ingredient_id = ?;
          ''', (ingredient_a, ingredient_b))
      return cursor.fetchone()[0]
    

    【讨论】:

      猜你喜欢
      • 2023-01-14
      • 1970-01-01
      • 1970-01-01
      • 2016-06-16
      • 2015-08-23
      • 2020-01-03
      • 1970-01-01
      • 1970-01-01
      • 2012-06-11
      相关资源
      最近更新 更多