【问题标题】:Bleu score in python from scratch从零开始在 python 中的 Bleu 分数
【发布时间】:2019-11-19 22:43:41
【问题描述】:

看了吴恩达关于Bleu score 的视频后,我想用python 从头开始​​实现一个。我用 numpy 谨慎地用 python 编写了完整的代码。这是完整的代码

import numpy as np

def n_gram_generator(sentence,n= 2,n_gram= False):
    '''
    N-Gram generator with parameters sentence
    n is for number of n_grams
    The n_gram parameter removes repeating n_grams 
    '''
    sentence = sentence.lower() # converting to lower case
    sent_arr = np.array(sentence.split()) # split to string arrays
    length = len(sent_arr)

    word_list = []
    for i in range(length+1):
        if i < n:
            continue
        word_range = list(range(i-n,i))
        s_list = sent_arr[word_range]
        string = ' '.join(s_list) # converting list to strings
        word_list.append(string) # append to word_list
        if n_gram:
            word_list = list(set(word_list))
    return word_list

def bleu_score(original,machine_translated):
    '''
    Bleu score function given a orginal and a machine translated sentences
    '''
    mt_length = len(machine_translated.split())
    o_length = len(original.split())

    # Brevity Penalty 
    if mt_length>o_length:
        BP=1
    else:
        penality=1-(mt_length/o_length)
        BP=np.exp(penality)

    # calculating precision
    precision_score = []
    for i in range(mt_length):
        original_n_gram = n_gram_generator(original,i)
        machine_n_gram = n_gram_generator(machine_translated,i)
        n_gram_list = list(set(machine_n_gram)) # removes repeating strings

        # counting number of occurence 
        machine_score = 0
        original_score = 0
        for j in n_gram_list:
            machine_count = machine_n_gram.count(j)
            original_count = original_n_gram.count(j)
            machine_score = machine_score+machine_count
            original_score = original_score+original_count

        precision = original_score/machine_score
        precision_score.append(precision)
    precisions_sum = np.array(precision_score).sum()
    avg_precisions_sum=precisions_sum/mt_length
    bleu=BP*np.exp(avg_precisions_sum)
    return bleu

if __name__ == "__main__":
    original = "this is a test"
    bs=bleu_score(original,original)
    print("Bleu Score Original",bs)

我尝试用 nltk's 测试我的分数

from nltk.translate.bleu_score import sentence_bleu
reference = [['this', 'is', 'a', 'test']]
candidate = ['this', 'is', 'a', 'test']
score = sentence_bleu(reference, candidate)
print(score)

问题是我的 bleu 分数大约是 2.718281 而 nltk 的分数是 1。我究竟做错了什么?

以下是一些可能的原因:

1) 我根据机器翻译句子的长度计算了 ngrams。这里从 1 到 4

2) n_gram_generator 自己编写的函数,不确定其准确性

3) 我使用了错误的函数或计算错误的 bleu 分数的一些方法

有人可以查看我的代码并告诉我哪里出错了吗?

【问题讨论】:

    标签: python machine-learning nlp nltk


    【解决方案1】:

    这里是修改后的解决方案

    # coding: utf-8
    
    import numpy as np
    from collections import Counter
    import math
    from nltk.translate.bleu_score import sentence_bleu
    
    
    def n_gram_generator(sentence,n= 2,n_gram= False):
        '''
        N-Gram generator with parameters sentence
        n is for number of n_grams
        The n_gram parameter removes repeating n_grams
        '''
        sentence = sentence.lower()  # converting to lower case
        sent_arr = np.array(sentence.split())  # split to string arrays
        length = len(sent_arr)
    
        word_list = []
        for i in range(length+1):
            if i < n:
                continue
            word_range = list(range(i-n,i))
            s_list = sent_arr[word_range]
            string = ' '.join(s_list)  # converting list to strings
            word_list.append(string) # append to word_list
            if n_gram:
                word_list = list(set(word_list))
        return word_list
    
    
    def bleu_score(original, machine_translated):
        '''
        Bleu score function given a orginal and a machine translated sentences
        '''
        mt_length = len(machine_translated.split())
        o_length  = len(original.split())
    
        # Brevity Penalty
        if mt_length > o_length:
            BP=1
        else:
            penality=1-(mt_length/o_length)
            BP = np.exp(penality)
    
        # Clipped precision
        clipped_precision_score = []
        for ngram_level in range(1, 5):  # 1-gram to 4-gram
            
            
            original_ngram_list = n_gram_generator(original, ngram_level)
            original_n_gram = Counter(original_ngram_list)
            
            machine_ngram_list = n_gram_generator(machine_translated, ngram_level)
            machine_n_gram = Counter(machine_ngram_list)
            
            
            num_ngrams_in_translation = sum(machine_n_gram.values())  # number of ngrams in translation
            
            # iterate the unique ngrams in translation (candidate)
            for j in machine_n_gram:
                
                if j in original_n_gram:  # if found in reference
                    
                    if machine_n_gram[j] > original_n_gram[j]:  # CLIPPING - if found in translation more than in source, clip
                        machine_n_gram[j] = original_n_gram[j]
                        
                else:
                    machine_n_gram[j] = 0
    
            #print (sum(machine_n_gram.values()), c)
            clipped_precision_score.append(float(sum(machine_n_gram.values())) / num_ngrams_in_translation)
    
        #print (clipped_precision_score)
    
        weights = [0.25]*4
    
        s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, clipped_precision_score))
        s = BP * math.exp(math.fsum(s))
        return s
    
    original = "It is a guide to action which ensures that the military alwasy obeys the command of the party"
    machine_translated = "It is the guiding principle which guarantees the military forces alwasy being under the command of the party"
    
    print (bleu_score(original, machine_translated))
    print (sentence_bleu([original.split()], machine_translated.split()))
    

    【讨论】:

      【解决方案2】:

      这里是实际nltksource code的略微修改版本:

      def sentence_bleu_man(
          references,
          hypothesis,
          weights=(0.25, 0.25, 0.25, 0.25)):
      
          # compute modified precision for 1-4 ngrams
          p_numerators = Counter()  
          p_denominators = Counter()  
          hyp_lengths, ref_lengths = 0, 0
      
          for i, _ in enumerate(weights, start=1):
              p_i = modified_precision(references, hypothesis, i)
              p_numerators[i] += p_i.numerator
              p_denominators[i] += p_i.denominator
      
          # compute brevity penalty    
          hyp_len = len(hypothesis)
          ref_len = closest_ref_length(references, hyp_len)
          bp = brevity_penalty(ref_len, hyp_len)
      
          # compute final score
          p_n = [
              Fraction(p_numerators[i], p_denominators[i], 
              _normalize=False)
              for i, _ in enumerate(weights, start=1)
              if p_numerators[i] > 0
          ]
          s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
          s = bp * math.exp(math.fsum(s))
      
          return s
      

      我们可以使用原始paper中的一个例子:

      rt_raw = [
      'It is a guide to action that ensures that the military will forever heed Party commands',
      'It is the guiding principle which guarantees the military forces always being under the command of the Party',
      'It is the practical guide for the army always to heed the directions of the party'
      ]
      
      ct_raw = [
      'It is a guide to action which ensures that the military always obeys the commands of the party',
      'It is to insure the troops forever hearing the activity guidebook that party direct'
      ]
      
      def process_trans(t):
          return t.lower().split()
      
      rt = [process_trans(t) for t in rt_raw]
      ct = [process_trans(t) for t in ct_raw]
      
      c1, c2 = ct[0], ct[1]
      
      sentence_bleu_man(rt, c2, weights=(.5, .5, 0, 0))
      sentence_bleu(rt, c2, weights=(.5, .5, 0, 0))
      

      输出:

      0.18174699151949172
      0.18174699151949172
      

      【讨论】:

        【解决方案3】:

        您的 bleu 分数计算错误。 问题:

        • 你必须使用裁剪精度
        • sklearn 对每个 n 克使用权重
        • sklearn 对 n = 1,2,3,4 使用 ngram

        更正的代码

        def bleu_score(original,machine_translated):
            '''
            Bleu score function given a orginal and a machine translated sentences
            '''
            mt_length = len(machine_translated.split())
            o_length = len(original.split())
        
            # Brevity Penalty 
            if mt_length>o_length:
                BP=1
            else:
                penality=1-(mt_length/o_length)
                BP=np.exp(penality)
        
            # Clipped precision
            clipped_precision_score = []
            for i in range(1, 5):
                original_n_gram = Counter(n_gram_generator(original,i))
                machine_n_gram = Counter(n_gram_generator(machine_translated,i))
        
                c = sum(machine_n_gram.values())
                for j in machine_n_gram:
                    if j in original_n_gram:
                        if machine_n_gram[j] > original_n_gram[j]:
                            machine_n_gram[j] = original_n_gram[j]
                    else:
                        machine_n_gram[j] = 0
        
                #print (sum(machine_n_gram.values()), c)
                clipped_precision_score.append(sum(machine_n_gram.values())/c)
        
            #print (clipped_precision_score)
        
            weights =[0.25]*4
        
            s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, clipped_precision_score))
            s = BP * math.exp(math.fsum(s))
            return s
        
        original = "It is a guide to action which ensures that the military alwasy obeys the command of the party"
        machine_translated = "It is the guiding principle which guarantees the military forces alwasy being under the command of the party"
        
        print (bleu_score(original, machine_translated))
        print (sentence_bleu([original.split()], machine_translated.split()))
        

        输出:

        0.27098211583470044
        0.27098211583470044
        

        【讨论】:

          猜你喜欢
          • 2020-08-19
          • 1970-01-01
          • 2021-11-09
          • 1970-01-01
          • 1970-01-01
          • 2019-07-17
          • 2017-09-29
          • 2011-07-04
          • 2021-06-23
          相关资源
          最近更新 更多