From bad1c59ff63ac99db1a87c394b896c20d7a9949a Mon Sep 17 00:00:00 2001 From: xiedeyantu Date: Mon, 1 Jun 2026 11:55:13 +0800 Subject: [PATCH] Fix log(0.0::float8) should error, not return -inf --- .../core/tests/expr_api/simplification.rs | 31 +++- datafusion/functions/src/math/log.rs | 152 +++++++++++++----- datafusion/sqllogictest/test_files/math.slt | 48 ++++++ datafusion/sqllogictest/test_files/scalar.slt | 9 +- 4 files changed, 188 insertions(+), 52 deletions(-) diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 245aba66849ce..7671c0a2d7733 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -606,24 +606,43 @@ fn test_simplify_with_cycle_count( #[test] fn test_simplify_log() { - // Log(c3, 1) ===> 0 + // Log(10, 1) ===> 0 + { + let expr = log(lit(10), lit(1)); + test_simplify(expr, lit(0)); + } + // Log(10, 10) ===> 1 + { + let expr = log(lit(10), lit(10)); + test_simplify(expr, lit(1)); + } + // Log(c3, 1) ===> Log(c3, 1) { let expr = log(col("c3_non_null"), lit(1)); - test_simplify(expr, lit(0i64)); + test_simplify(expr.clone(), expr); } - // Log(c3, c3) ===> 1 + // Log(10, Power(10, c4)) ===> c4 + { + let expr = log(lit(10), power(lit(10), col("c4_non_null"))); + let expected = col("c4_non_null"); + test_simplify(expr, expected); + } + // Log(c3, c3) ===> Log(c3, c3) { let expr = log(col("c3_non_null"), col("c3_non_null")); - let expected = lit(1i64); + let expected = log(col("c3_non_null"), col("c3_non_null")); test_simplify(expr, expected); } - // Log(c3, Power(c3, c4)) ===> c4 + // Log(c3, Power(c3, c4)) ===> Log(c3, Power(c3, c4)) { let expr = log( col("c3_non_null"), power(col("c3_non_null"), col("c4_non_null")), ); - let expected = col("c4_non_null"); + let expected = log( + col("c3_non_null"), + power(col("c3_non_null"), col("c4_non_null")), + ); test_simplify(expr, expected); } // Log(c3, c4) ===> Log(c3, c4) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ac94f78e0c723..0dd3d311d0674 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -106,6 +106,17 @@ fn is_valid_integer_base(base: f64) -> bool { base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64 } +#[inline] +fn validate_log_value(value: f64) -> Result<(), ArrowError> { + if value == 0.0 { + Err(ArrowError::ComputeError( + "cannot take logarithm of zero".to_string(), + )) + } else { + Ok(()) + } +} + /// Calculate logarithm for Decimal32 values. /// For integer bases >= 2 with zero scale, return an exact integer log when the /// value is a perfect power of the base. Otherwise falls back to f64 computation. @@ -121,7 +132,10 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + decimal_to_f64(value, scale).and_then(|v| { + validate_log_value(v)?; + Ok(v.log(base)) + }) } /// Calculate logarithm for Decimal64 values. @@ -139,7 +153,10 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + decimal_to_f64(value, scale).and_then(|v| { + validate_log_value(v)?; + Ok(v.log(base)) + }) } /// Calculate logarithm for Decimal128 values. @@ -157,7 +174,20 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + decimal_to_f64(value, scale).and_then(|v| { + validate_log_value(v)?; + Ok(v.log(base)) + }) +} + +/// Compute logarithm for Float16, Float32, and Float64 values +#[inline] +fn compute_float_log(value: T, base: T) -> Result { + let value_f64 = value.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert value to f64".to_string()) + })?; + validate_log_value(value_f64)?; + Ok(value.log(base)) } /// Convert a scaled decimal value to f64. @@ -180,7 +210,9 @@ fn log_decimal256(value: i256, scale: i8, base: f64) -> Result ArrowError::ComputeError(format!("Cannot convert {value} to f64")) })?; let scale_factor = 10f64.powi(scale as i32); - Ok((value_f64 / scale_factor).log(base)) + let value = value_f64 / scale_factor; + validate_log_value(value)?; + Ok(value.log(base)) } } } @@ -247,27 +279,24 @@ impl ScalarUDFImpl for LogFunc { let value = value.to_array(args.number_rows)?; let output: ArrayRef = match value.data_type() { - DataType::Float16 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float32 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float64 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } + DataType::Float16 => calculate_binary_math::< + Float16Type, + Float16Type, + Float16Type, + _, + >(&value, &base, compute_float_log)?, + DataType::Float32 => calculate_binary_math::< + Float32Type, + Float32Type, + Float32Type, + _, + >(&value, &base, compute_float_log)?, + DataType::Float64 => calculate_binary_math::< + Float64Type, + Float64Type, + Float64Type, + _, + >(&value, &base, compute_float_log)?, DataType::Decimal32(_, scale) => { calculate_binary_math::( &value, @@ -308,10 +337,9 @@ impl ScalarUDFImpl for LogFunc { self.doc() } - /// Simplify the `log` function by the relevant rules: - /// 1. Log(a, 1) ===> 0 - /// 2. Log(a, Power(a, b)) ===> b - /// 3. Log(a, a) ===> 1 + /// Simplify `log` only when the base is a known-valid literal. + /// This preserves current runtime `NaN` / domain behavior for column and + /// expression bases whose validity cannot be proven during planning. fn simplify( &self, mut args: Vec, @@ -358,43 +386,83 @@ impl ScalarUDFImpl for LogFunc { } else { lit(ScalarValue::new_ten(&number_datatype)?) }; + let base_datatype = info.get_data_type(&base)?; + + if is_zero_literal(&number, &number_datatype)? + || is_zero_literal(&base, &base_datatype)? + { + return Ok(ExprSimplifyResult::Original(original_log_args( + num_args, &base, &number, + )?)); + } + + let base_is_valid_literal = is_valid_log_base_literal(&base)?; match number { Expr::Literal(value, _) - if value == ScalarValue::new_one(&number_datatype)? => + if value == ScalarValue::new_one(&number_datatype)? + && base_is_valid_literal => { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) } Expr::ScalarFunction(ScalarFunction { func, mut args }) - if is_pow(&func) && args.len() == 2 && base == args[0] => + if is_pow(&func) + && args.len() == 2 + && base == args[0] + && base_is_valid_literal => { let b = args.pop().unwrap(); // length checked above Ok(ExprSimplifyResult::Simplified(b)) } number => { - if number == base { + if number == base && base_is_valid_literal { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( &number_datatype, )?))) } else { - let args = match num_args { - 1 => vec![number], - 2 => vec![base, number], - _ => { - return internal_err!( - "Unexpected number of arguments in log::simplify" - ); - } - }; - Ok(ExprSimplifyResult::Original(args)) + Ok(ExprSimplifyResult::Original(original_log_args( + num_args, &base, &number, + )?)) } } } } } +#[inline] +fn original_log_args(num_args: usize, base: &Expr, number: &Expr) -> Result> { + match num_args { + 1 => Ok(vec![number.clone()]), + 2 => Ok(vec![base.clone(), number.clone()]), + _ => { + internal_err!("Unexpected number of arguments in log::simplify") + } + } +} + +#[inline] +fn is_zero_literal(expr: &Expr, data_type: &DataType) -> Result { + match expr { + Expr::Literal(value, _) => Ok(*value == ScalarValue::new_zero(data_type)?), + _ => Ok(false), + } +} + +#[inline] +fn is_valid_log_base_literal(expr: &Expr) -> Result { + match expr { + Expr::Literal(value, _) => { + let scalar = value.cast_to(&DataType::Float64)?; + Ok( + matches!(scalar, ScalarValue::Float64(Some(base)) if base > 0.0 && base != 1.0), + ) + } + _ => Ok(false), + } +} + /// Returns true if the function is `PowerFunc` fn is_pow(func: &ScalarUDF) -> bool { func.inner().is::() diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 475434883d315..dc1be507a6a86 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -963,3 +963,51 @@ SELECT gcd(column1, 0) FROM (VALUES (7), (-3), (0)); 7 3 0 + +# Verify error handling for log with zero values +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(0, 0); + +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(0); + +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(2, 0); + +# Safe literal-base rewrites must preserve current results. +query BBBB +SELECT log(10, 1) = 0, log(10, 10) = 1, log(10) = 1, log(10, power(10, 2)) = 2; +---- +true true true true + +query B rowsort +SELECT log(10, power(10, column1)) = column1 +FROM (VALUES (2), (3)) AS t(column1); +---- +true +true + +# Invalid literal bases must keep runtime NaN behavior rather than being simplified. +query BBB +SELECT isnan(log(1, 1)), isnan(log(-2, 1)), isnan(log(1, power(1, 2))); +---- +true true true + +query B rowsort +SELECT isnan(log(1, power(1, column1))) +FROM (VALUES (2), (3)) AS t(column1); +---- +true +true + +# log(col, power(col, b)) must NOT be simplified away when col may be zero at runtime. +# Before the fix the planner rewrote log(a, power(a, b)) => b for any expression a, +# so the row where column1=0 silently returned 2.0 instead of raising an error. +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(column1, power(column1, 2)) +FROM (VALUES (0.0), (2.0)) AS t(column1); + +# log(col, col) must also preserve runtime validation for zero rows. +query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero +SELECT log(column1, column1) +FROM (VALUES (0.0), (2.0)) AS t(column1); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 38f76f13151bc..cd54fd29c0cef 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -625,10 +625,11 @@ select log(2, 2.0/3) a, log(10, 2.0/3) b; # log scalar ops with zero edgecases # please see https://github.com/apache/datafusion/pull/5245#issuecomment-1426828382 -query RR rowsort -select log(0) a, log(1, 64) b; ----- --Infinity Infinity +query error cannot take logarithm of zero +select log(0) a; + +query error cannot take logarithm of zero +select log(0, power(0, 2)) a; # log with columns #1 query RRR rowsort