diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index 3b4f973f19d62..5f9416c2ebc6c 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::array::{ArrayRef, AsArray}; use std::sync::Arc; -use arrow::datatypes::DataType::Int64; -use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::DataType::{Decimal256, Int64}; +use arrow::datatypes::{DECIMAL256_MAX_PRECISION, DataType, Decimal256Type, Int64Type}; +use arrow_buffer::i256; use datafusion_common::{ Result, ScalarValue, exec_err, internal_err, utils::take_function_args, @@ -63,6 +64,8 @@ impl FactorialFunc { } } +const FACTORIAL_RETURN_TYPE: DataType = Decimal256(DECIMAL256_MAX_PRECISION, 0); + impl ScalarUDFImpl for FactorialFunc { fn name(&self) -> &str { "factorial" @@ -73,7 +76,7 @@ impl ScalarUDFImpl for FactorialFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Int64) + Ok(FACTORIAL_RETURN_TYPE) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -82,13 +85,21 @@ impl ScalarUDFImpl for FactorialFunc { match arg { ColumnarValue::Scalar(scalar) => { if scalar.is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))); + return Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + None, + DECIMAL256_MAX_PRECISION, + 0, + ))); } match scalar { ScalarValue::Int64(Some(v)) => { let result = compute_factorial(v)?; - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(result), + DECIMAL256_MAX_PRECISION, + 0, + ))) } _ => { internal_err!( @@ -100,9 +111,10 @@ impl ScalarUDFImpl for FactorialFunc { } ColumnarValue::Array(array) => match array.data_type() { Int64 => { - let result: Int64Array = array + let result = array .as_primitive::() - .try_unary(compute_factorial)?; + .try_unary::<_, Decimal256Type, _>(compute_factorial)? + .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0)?; Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } other => { @@ -117,36 +129,107 @@ impl ScalarUDFImpl for FactorialFunc { } } -const FACTORIALS: [i64; 21] = [ - 1, - 1, - 2, - 6, - 24, - 120, - 720, - 5040, - 40320, - 362880, - 3628800, - 39916800, - 479001600, - 6227020800, - 87178291200, - 1307674368000, - 20922789888000, - 355687428096000, - 6402373705728000, - 121645100408832000, - 2432902008176640000, -]; // if return type changes, this constant needs to be updated accordingly - -fn compute_factorial(n: i64) -> Result { +const FACTORIALS: [i256; 57] = [ + i256::from_parts(1, 0), + i256::from_parts(1, 0), + i256::from_parts(2, 0), + i256::from_parts(6, 0), + i256::from_parts(24, 0), + i256::from_parts(120, 0), + i256::from_parts(720, 0), + i256::from_parts(5040, 0), + i256::from_parts(40320, 0), + i256::from_parts(362880, 0), + i256::from_parts(3628800, 0), + i256::from_parts(39916800, 0), + i256::from_parts(479001600, 0), + i256::from_parts(6227020800, 0), + i256::from_parts(87178291200, 0), + i256::from_parts(1307674368000, 0), + i256::from_parts(20922789888000, 0), + i256::from_parts(355687428096000, 0), + i256::from_parts(6402373705728000, 0), + i256::from_parts(121645100408832000, 0), + i256::from_parts(2432902008176640000, 0), + i256::from_parts(51090942171709440000, 0), + i256::from_parts(1124000727777607680000, 0), + i256::from_parts(25852016738884976640000, 0), + i256::from_parts(620448401733239439360000, 0), + i256::from_parts(15511210043330985984000000, 0), + i256::from_parts(403291461126605635584000000, 0), + i256::from_parts(10888869450418352160768000000, 0), + i256::from_parts(304888344611713860501504000000, 0), + i256::from_parts(8841761993739701954543616000000, 0), + i256::from_parts(265252859812191058636308480000000, 0), + i256::from_parts(8222838654177922817725562880000000, 0), + i256::from_parts(263130836933693530167218012160000000, 0), + i256::from_parts(8683317618811886495518194401280000000, 0), + i256::from_parts(295232799039604140847618609643520000000, 0), + i256::from_parts(124676958757991025765413114570153656320, 30), + i256::from_parts(64699745315476902531002227912544878592, 1093), + i256::from_parts(11914008226076149403460180741783027712, 40448), + i256::from_parts(112449945669955213868112260755986841600, 1537025), + i256::from_parts(302159478076991779295882880302268284928, 59943987), + i256::from_parts(176496280846824950617203951978843996160, 2397759515), + i256::from_parts(90417809380115242574495275065471401984, 98308140136), + i256::from_parts(54441957834517090031680871000348557312, 4128941885723), + i256::from_parts(299309985358604090582029808424378695680, 177544501086095), + i256::from_parts(238909412782918374001076488265470574592, 7811958047788218), + i256::from_parts(202170200682234462683829141561361301504, 351538112150469841), + i256::from_parts( + 112205324517446769945026111164878159872, + 16170753158921612713, + ), + i256::from_parts( + 169414748505921235465608113272750342144, + 760025398469315797526, + ), + i256::from_parts( + 305413489102634642691573466161347559424, + 36481219126527158281271, + ), + i256::from_parts( + 333119188428743562961991722339997319168, + 1787579737199830755782322, + ), + i256::from_parts( + 322405809232131901857604960274991808512, + 89378986859991537789116148, + ), + i256::from_parts( + 109142658633680748495871817299708084224, + 4558328329859568427244923596, + ), + i256::from_parts( + 230900378216383506371340780676528996352, + 237033073152697558216736027008, + ), + i256::from_parts( + 327837203235479616462950115744149405696, + 12562752877092970585487009431459, + ), + i256::from_parts( + 8525894827099188903826663732120911872, + 678388655363020411616298509298838, + ), + i256::from_parts( + 128641848569516926247091897834881941504, + 37311376044966122638896418011436091, + ), + i256::from_parts( + 58013814553240137106279522686256283648, + 2089437058518102867778199408640421117, + ), +]; + +fn compute_factorial(n: i64) -> Result { if n < 0 { - exec_err!("factorial of a negative number is undefined") - } else if n < FACTORIALS.len() as i64 { - Ok(FACTORIALS[n as usize]) - } else { - exec_err!("Overflow happened on FACTORIAL({n})") + return exec_err!("factorial of a negative number is undefined"); } + + if let Some(value) = FACTORIALS.get(n as usize) { + return Ok(*value); + } + + exec_err!("Overflow happened on FACTORIAL({n})") } diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 38f76f13151bc..9e8bea05a412e 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -461,19 +461,37 @@ select round(exp(a), 5), round(exp(e), 5), round(exp(f), 5) from signed_integers ## factorial # factorial scalar function -query III rowsort +query RRR rowsort select factorial(0), factorial(10), factorial(15); ---- 1 3628800 1307674368000 +query TR +select arrow_typeof(factorial(21)), factorial(21); +---- +Decimal256(76, 0) 51090942171709440000 + +query R +select factorial(21); +---- +51090942171709440000 + +query TR +select arrow_typeof(factorial(56)), factorial(56); +---- +Decimal256(76, 0) 710998587804863451854045647463724949736497978881168458687447040000000000000 + +query error DataFusion error: Execution error: Overflow happened on FACTORIAL\(57\) +select factorial(57); + # factorial scalar nulls -query I rowsort +query R rowsort select factorial(null); ---- NULL # factorial with columns -query III rowsort +query RRR rowsort select factorial(a), factorial(e), factorial(f) from unsigned_integers; ---- 1 24 3628800