Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,24 +606,43 @@ fn test_simplify_with_cycle_count(

#[test]
fn test_simplify_log() {
// Log(c3, 1) ===> 0
// Log(10, 1) ===> 0
{
let expr = log(lit(10), lit(1));
test_simplify(expr, lit(0));
}
// Log(10, 10) ===> 1
{
let expr = log(lit(10), lit(10));
test_simplify(expr, lit(1));
}
// Log(c3, 1) ===> Log(c3, 1)
{
let expr = log(col("c3_non_null"), lit(1));
test_simplify(expr, lit(0i64));
test_simplify(expr.clone(), expr);
}
// Log(c3, c3) ===> 1
// Log(10, Power(10, c4)) ===> c4
{
let expr = log(lit(10), power(lit(10), col("c4_non_null")));
let expected = col("c4_non_null");
test_simplify(expr, expected);
}
// Log(c3, c3) ===> Log(c3, c3)
{
let expr = log(col("c3_non_null"), col("c3_non_null"));
let expected = lit(1i64);
let expected = log(col("c3_non_null"), col("c3_non_null"));
test_simplify(expr, expected);
}
// Log(c3, Power(c3, c4)) ===> c4
// Log(c3, Power(c3, c4)) ===> Log(c3, Power(c3, c4))
{
let expr = log(
col("c3_non_null"),
power(col("c3_non_null"), col("c4_non_null")),
);
let expected = col("c4_non_null");
let expected = log(
col("c3_non_null"),
power(col("c3_non_null"), col("c4_non_null")),
);
test_simplify(expr, expected);
}
// Log(c3, c4) ===> Log(c3, c4)
Expand Down
152 changes: 110 additions & 42 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ fn is_valid_integer_base(base: f64) -> bool {
base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64
}

#[inline]
fn validate_log_value(value: f64) -> Result<(), ArrowError> {
if value == 0.0 {
Err(ArrowError::ComputeError(
"cannot take logarithm of zero".to_string(),
))
} else {
Ok(())
}
}

/// Calculate logarithm for Decimal32 values.
/// For integer bases >= 2 with zero scale, return an exact integer log when the
/// value is a perfect power of the base. Otherwise falls back to f64 computation.
Expand All @@ -121,7 +132,10 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
decimal_to_f64(value, scale).and_then(|v| {
validate_log_value(v)?;
Ok(v.log(base))
})
}

/// Calculate logarithm for Decimal64 values.
Expand All @@ -139,7 +153,10 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
decimal_to_f64(value, scale).and_then(|v| {
validate_log_value(v)?;
Ok(v.log(base))
})
}

/// Calculate logarithm for Decimal128 values.
Expand All @@ -157,7 +174,20 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError>
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
decimal_to_f64(value, scale).and_then(|v| {
validate_log_value(v)?;
Ok(v.log(base))
})
}

/// Compute logarithm for Float16, Float32, and Float64 values
#[inline]
fn compute_float_log<T: Float + ToPrimitive>(value: T, base: T) -> Result<T, ArrowError> {
let value_f64 = value.to_f64().ok_or_else(|| {
ArrowError::ComputeError("Cannot convert value to f64".to_string())
})?;
validate_log_value(value_f64)?;
Ok(value.log(base))
}

/// Convert a scaled decimal value to f64.
Expand All @@ -180,7 +210,9 @@ fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError>
ArrowError::ComputeError(format!("Cannot convert {value} to f64"))
})?;
let scale_factor = 10f64.powi(scale as i32);
Ok((value_f64 / scale_factor).log(base))
let value = value_f64 / scale_factor;
validate_log_value(value)?;
Ok(value.log(base))
}
}
}
Expand Down Expand Up @@ -247,27 +279,24 @@ impl ScalarUDFImpl for LogFunc {
let value = value.to_array(args.number_rows)?;

let output: ArrayRef = match value.data_type() {
DataType::Float16 => {
calculate_binary_math::<Float16Type, Float16Type, Float16Type, _>(
&value,
&base,
|value, base| Ok(value.log(base)),
)?
}
DataType::Float32 => {
calculate_binary_math::<Float32Type, Float32Type, Float32Type, _>(
&value,
&base,
|value, base| Ok(value.log(base)),
)?
}
DataType::Float64 => {
calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
&value,
&base,
|value, base| Ok(value.log(base)),
)?
}
DataType::Float16 => calculate_binary_math::<
Comment thread
xiedeyantu marked this conversation as resolved.
Float16Type,
Float16Type,
Float16Type,
_,
>(&value, &base, compute_float_log)?,
DataType::Float32 => calculate_binary_math::<
Float32Type,
Float32Type,
Float32Type,
_,
>(&value, &base, compute_float_log)?,
DataType::Float64 => calculate_binary_math::<
Float64Type,
Float64Type,
Float64Type,
_,
>(&value, &base, compute_float_log)?,
DataType::Decimal32(_, scale) => {
calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>(
&value,
Expand Down Expand Up @@ -308,10 +337,9 @@ impl ScalarUDFImpl for LogFunc {
self.doc()
}

/// Simplify the `log` function by the relevant rules:
/// 1. Log(a, 1) ===> 0
/// 2. Log(a, Power(a, b)) ===> b
/// 3. Log(a, a) ===> 1
/// Simplify `log` only when the base is a known-valid literal.
/// This preserves current runtime `NaN` / domain behavior for column and
/// expression bases whose validity cannot be proven during planning.
fn simplify(
&self,
mut args: Vec<Expr>,
Expand Down Expand Up @@ -358,43 +386,83 @@ impl ScalarUDFImpl for LogFunc {
} else {
lit(ScalarValue::new_ten(&number_datatype)?)
};
let base_datatype = info.get_data_type(&base)?;

if is_zero_literal(&number, &number_datatype)?
|| is_zero_literal(&base, &base_datatype)?
{
return Ok(ExprSimplifyResult::Original(original_log_args(
num_args, &base, &number,
)?));
}

let base_is_valid_literal = is_valid_log_base_literal(&base)?;

match number {
Expr::Literal(value, _)
if value == ScalarValue::new_one(&number_datatype)? =>
if value == ScalarValue::new_one(&number_datatype)?
&& base_is_valid_literal =>
{
Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero(
&info.get_data_type(&base)?,
)?)))
}
Expr::ScalarFunction(ScalarFunction { func, mut args })
if is_pow(&func) && args.len() == 2 && base == args[0] =>
if is_pow(&func)
&& args.len() == 2
&& base == args[0]
&& base_is_valid_literal =>
{
let b = args.pop().unwrap(); // length checked above
Ok(ExprSimplifyResult::Simplified(b))
}
number => {
if number == base {
if number == base && base_is_valid_literal {
Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
&number_datatype,
)?)))
} else {
let args = match num_args {
1 => vec![number],
2 => vec![base, number],
_ => {
return internal_err!(
"Unexpected number of arguments in log::simplify"
);
}
};
Ok(ExprSimplifyResult::Original(args))
Ok(ExprSimplifyResult::Original(original_log_args(
num_args, &base, &number,
)?))
}
}
}
}
}

#[inline]
fn original_log_args(num_args: usize, base: &Expr, number: &Expr) -> Result<Vec<Expr>> {
match num_args {
1 => Ok(vec![number.clone()]),
2 => Ok(vec![base.clone(), number.clone()]),
_ => {
internal_err!("Unexpected number of arguments in log::simplify")
}
}
}

#[inline]
fn is_zero_literal(expr: &Expr, data_type: &DataType) -> Result<bool> {
match expr {
Expr::Literal(value, _) => Ok(*value == ScalarValue::new_zero(data_type)?),
_ => Ok(false),
}
}

#[inline]
fn is_valid_log_base_literal(expr: &Expr) -> Result<bool> {
match expr {
Expr::Literal(value, _) => {
let scalar = value.cast_to(&DataType::Float64)?;
Ok(
matches!(scalar, ScalarValue::Float64(Some(base)) if base > 0.0 && base != 1.0),
)
}
_ => Ok(false),
}
}

/// Returns true if the function is `PowerFunc`
fn is_pow(func: &ScalarUDF) -> bool {
func.inner().is::<PowerFunc>()
Expand Down
48 changes: 48 additions & 0 deletions datafusion/sqllogictest/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,51 @@ SELECT gcd(column1, 0) FROM (VALUES (7), (-3), (0));
7
3
0

# Verify error handling for log with zero values
query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero
SELECT log(0, 0);

query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero
SELECT log(0);

query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero
SELECT log(2, 0);

# Safe literal-base rewrites must preserve current results.
query BBBB
SELECT log(10, 1) = 0, log(10, 10) = 1, log(10) = 1, log(10, power(10, 2)) = 2;
----
true true true true

query B rowsort
SELECT log(10, power(10, column1)) = column1
FROM (VALUES (2), (3)) AS t(column1);
----
true
true

# Invalid literal bases must keep runtime NaN behavior rather than being simplified.
query BBB
SELECT isnan(log(1, 1)), isnan(log(-2, 1)), isnan(log(1, power(1, 2)));
----
true true true

query B rowsort
SELECT isnan(log(1, power(1, column1)))
FROM (VALUES (2), (3)) AS t(column1);
----
true
true

# log(col, power(col, b)) must NOT be simplified away when col may be zero at runtime.
# Before the fix the planner rewrote log(a, power(a, b)) => b for any expression a,
# so the row where column1=0 silently returned 2.0 instead of raising an error.
query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero
SELECT log(column1, power(column1, 2))
FROM (VALUES (0.0), (2.0)) AS t(column1);

# log(col, col) must also preserve runtime validation for zero rows.
query error DataFusion error: Arrow error: Compute error: cannot take logarithm of zero
SELECT log(column1, column1)
FROM (VALUES (0.0), (2.0)) AS t(column1);
9 changes: 5 additions & 4 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -625,10 +625,11 @@ select log(2, 2.0/3) a, log(10, 2.0/3) b;

# log scalar ops with zero edgecases
# please see https://github.com/apache/datafusion/pull/5245#issuecomment-1426828382
query RR rowsort
select log(0) a, log(1, 64) b;
----
-Infinity Infinity
query error cannot take logarithm of zero
select log(0) a;

query error cannot take logarithm of zero
select log(0, power(0, 2)) a;

# log with columns #1
query RRR rowsort
Expand Down
Loading