【Question title】: PySpark when/otherwise statement returning incorrect output
【Posted】: 2021-10-14 17:41:11
【Question】:

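I have pasted my code below. I expect it to return 1 when col2 = 7, but it sometimes returns 1 and at other times returns 2. Once col2 is set, I do not perform any further operations on it. Has anyone experienced this odd behaviour? Or is it because the bounds of the conditions overlap?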

 df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                             .when(F.col('col2').between(7,14), 2)
                             .when(F.col('col2').between(14,21), 3)
                             .when(F.col('col2').between(21,28), 4)
                             .otherwise(5))

【Discussion】:

    Tags: apache-spark pyspark apache-spark-sql case


    【Solution 1】:

    I would say this is unexpected, because CodeGen translates case-when into a sequence of ifs that are checked in order. So for col2 = 7 you should always see 1 in 'col1'.

    You can inspect the actual code Spark generates with QueryExecution.debug.codegen, like this:

    >>> df = spark.range(1000)
    >>> from pyspark.sql.functions import *
    >>> dff = df.withColumn('col1',when(col('id').between(1,7),1).when(col('id').between(7,14),2).otherwise(3))
    
    >>> dff._jdf.queryExecution().debug().codegen()
    
    Found 1 WholeStageCodegen subtrees.
    == Subtree 1 / 1 ==
    *(1) Project [id#4L, CASE WHEN ((id#4L >= 1) && (id#4L <= 7)) THEN 1 WHEN ((id#4L >= 7) && (id#4L <= 14)) THEN 2 ELSE 3 END AS col1#6]
    +- *(1) Range (0, 1000, step=1, splits=2)
    
    Generated code:
    /* 001 */ public Object generate(Object[] references) {
    /* 002 */   return new GeneratedIteratorForCodegenStage1(references);
    /* 003 */ }
    /* 004 */
    /* 005 */ // codegenStageId=1
    /* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
    /* 007 */   private Object[] references;
    /* 008 */   private scala.collection.Iterator[] inputs;
    /* 009 */   private boolean range_initRange_0;
    /* 010 */   private long range_number_0;
    /* 011 */   private TaskContext range_taskContext_0;
    /* 012 */   private InputMetrics range_inputMetrics_0;
    /* 013 */   private long range_batchEnd_0;
    /* 014 */   private long range_numElementsTodo_0;
    /* 015 */   private int project_project_value_2_0;
    /* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
    /* 017 */
    /* 018 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
    /* 019 */     this.references = references;
    /* 020 */   }
    /* 021 */
    /* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
    /* 023 */     partitionIndex = index;
    /* 024 */     this.inputs = inputs;
    /* 025 */
    /* 026 */     range_taskContext_0 = TaskContext.get();
    /* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
    /* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    /* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    /* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
    /* 031 */
    /* 032 */   }
    /* 033 */
    /* 034 */   private void project_doConsume_0(long project_expr_0_0) throws java.io.IOException {
    /* 035 */     byte project_caseWhenResultState_0 = -1;
    /* 036 */     do {
    /* 037 */       boolean project_value_4 = false;
    /* 038 */       project_value_4 = project_expr_0_0 >= 1L;
    /* 039 */       boolean project_value_3 = false;
    /* 040 */
    /* 041 */       if (project_value_4) {
    /* 042 */         boolean project_value_7 = false;
    /* 043 */         project_value_7 = project_expr_0_0 <= 7L;
    /* 044 */         project_value_3 = project_value_7;
    /* 045 */       }
    /* 046 */       if (!false && project_value_3) {
    /* 047 */         project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
    /* 048 */         project_project_value_2_0 = 1;
    /* 049 */         continue;
    /* 050 */       }
    /* 051 */
    /* 052 */       boolean project_value_12 = false;
    /* 053 */       project_value_12 = project_expr_0_0 >= 7L;
    /* 054 */       boolean project_value_11 = false;
    /* 055 */
    /* 056 */       if (project_value_12) {
    /* 057 */         boolean project_value_15 = false;
    /* 058 */         project_value_15 = project_expr_0_0 <= 14L;
    /* 059 */         project_value_11 = project_value_15;
    /* 060 */       }
    /* 061 */       if (!false && project_value_11) {
    /* 062 */         project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
    /* 063 */         project_project_value_2_0 = 2;
    /* 064 */         continue;
    /* 065 */       }
    /* 066 */
    /* 067 */       project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
    /* 068 */       project_project_value_2_0 = 3;
    /* 069 */
    /* 070 */     } while (false);
    /* 071 */     // TRUE if any condition is met and the result is null, or no any condition is met.
    /* 072 */     final boolean project_isNull_2 = (project_caseWhenResultState_0 != 0);
    /* 073 */     range_mutableStateArray_0[2].reset();
    /* 074 */
    /* 075 */     range_mutableStateArray_0[2].zeroOutNullBytes();
    /* 076 */
    /* 077 */     range_mutableStateArray_0[2].write(0, project_expr_0_0);
    /* 078 */
    /* 079 */     range_mutableStateArray_0[2].write(1, project_project_value_2_0);
    /* 080 */     append((range_mutableStateArray_0[2].getRow()));
    /* 081 */
    /* 082 */   }
    /* 083 */
    ...
    

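    The method of interest is private void project_doConsume_0(... (starting at line 34). Each branch that matches assigns the result and then hits continue, which leaves the do { ... } while (false) block, so only the first matching condition takes effect.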
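As a reading aid, the control flow of that method can be transliterated into plain Python. This is a hand-written sketch, not something Spark produces, and the function name `case_when_col1` is made up for illustration. Each `continue` in the generated Java becomes an early `return`, which makes the first-match-wins behaviour easy to see.

```python
def case_when_col1(id_value: int) -> int:
    """Plain-Python transliteration of project_doConsume_0 (illustrative only)."""
    # First WHEN: id >= 1 AND id <= 7 (generated lines 037-050).
    if id_value >= 1 and id_value <= 7:
        return 1  # plays the role of `continue`: later branches are skipped

    # Second WHEN: id >= 7 AND id <= 14 (generated lines 052-065).
    if id_value >= 7 and id_value <= 14:
        return 2

    # ELSE branch (generated lines 067-068).
    return 3


for v in (0, 1, 7, 8, 14, 15):
    print(v, case_when_col1(v))
```

For an integer 7 the first branch returns before the second condition is ever tested, even though that condition would also be true.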

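    【Comments】:

    • Thank you for the detailed explanation. We do see this now. We looked into the output further and realized that decimal values of 7 were causing the issue. We have corrected that and the code now works fine.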
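The reported symptom can be reproduced with a small plain-Python model of the original conditions. It assumes, as one reading of the comment above, that `col2` held fractional values such as 7.2 that looked like 7 when displayed. The function name `original_bucket` and the sample values are illustrative, and `between(a, b)` is modelled as `a <= x <= b`, matching the `>=` / `<=` pair in the plan printed in Solution 1.

```python
def original_bucket(col2: float) -> int:
    """Model of the question's when/otherwise chain; between() is inclusive."""
    if 1 <= col2 <= 7:
        return 1
    if 7 <= col2 <= 14:
        return 2
    if 14 <= col2 <= 21:
        return 3
    if 21 <= col2 <= 28:
        return 4
    return 5


for value in (7, 7.0, 7.2):
    # round() stands in for a display that hides the fractional part.
    print(f"shown as {round(value)} -> col1 = {original_bucket(value)}")
```

An exact 7 is caught by the first branch, while 7.2 fails `col2 <= 7` and is matched only by the second branch, which would explain seeing both 1 and 2 for rows that appear to contain 7.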
    【Solution 2】:

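    First point: between is inclusive, and your intervals overlap (7 can evaluate to True for both the first and the second interval, because both include 7).

    So this should be an improvement: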

     df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                                 .when(F.col('col2').between(8,14), 2)
                                 .when(F.col('col2').between(15,21), 3)
                                 .when(F.col('col2').between(22,28), 4)
                                 .otherwise(5))
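One caveat worth checking, given the comment under Solution 1 about decimals: integer-only bounds such as `between(1,7)` and `between(8,14)` remove the overlap but leave a gap for fractional values. The sketch below models both variants in plain Python. `corrected_bucket` and `half_open_bucket` are hypothetical names, and the half-open boundaries (`lower < x <= upper`) are an assumption about the intended behaviour, not something stated in the answer.

```python
def corrected_bucket(col2: float) -> int:
    """Model of the non-overlapping between() version shown above."""
    if 1 <= col2 <= 7:
        return 1
    if 8 <= col2 <= 14:
        return 2
    if 15 <= col2 <= 21:
        return 3
    if 22 <= col2 <= 28:
        return 4
    return 5


def half_open_bucket(col2: float) -> int:
    """Assumed alternative: [1, 7], then (7, 14], (14, 21], (21, 28]."""
    if 1 <= col2 <= 7:
        return 1
    if 7 < col2 <= 14:
        return 2
    if 14 < col2 <= 21:
        return 3
    if 21 < col2 <= 28:
        return 4
    return 5


for value in (7, 7.5, 14, 14.5):
    print(value, corrected_bucket(value), half_open_bucket(value))
```

With the integer bounds, 7.5 matches no interval and falls through to the default of 5. With half-open intervals it lands in bucket 2. In PySpark the half-open form would be written with comparison operators on the column, for example `(F.col('col2') > 7) & (F.col('col2') <= 14)`.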
    

    But when using multiple F.when() calls, I find I run into less trouble by nesting them inside .otherwise(F.when()), like this:

     df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                                 .otherwise(F.when(F.col('col2').between(8,14), 2)
                                 .otherwise(F.when(F.col('col2').between(15,21), 3)
                                 .otherwise(F.when(F.col('col2').between(22,28), 4)
                                 .otherwise(5)))))
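For what it is worth, the chained and nested forms should produce the same values, since chained `when` calls are evaluated in order and stop at the first match (as the generated code in Solution 1 shows). The following plain-Python sketch of that equivalence uses illustrative function names and is a model of the logic only, not PySpark code.

```python
def chained(col2: float) -> int:
    """Model of when(...).when(...).when(...).when(...).otherwise(5)."""
    if 1 <= col2 <= 7:
        return 1
    elif 8 <= col2 <= 14:
        return 2
    elif 15 <= col2 <= 21:
        return 3
    elif 22 <= col2 <= 28:
        return 4
    else:
        return 5


def nested(col2: float) -> int:
    """Model of when(...).otherwise(when(...).otherwise(...))."""
    return 1 if 1 <= col2 <= 7 else (
        2 if 8 <= col2 <= 14 else (
            3 if 15 <= col2 <= 21 else (
                4 if 22 <= col2 <= 28 else 5)))


# Compare both forms on 0, 0.5, 1.0, ..., 30.0.
samples = [x / 2 for x in range(0, 61)]
assert all(chained(x) == nested(x) for x in samples)
print("chained and nested agree on", len(samples), "sample values")
```

Under this model, the choice between the two forms is a matter of readability rather than behaviour.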
    

    【Comments】:

    • Thank you for the reply. We addressed it by removing between and using > and