Skip to content
257 changes: 244 additions & 13 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,20 +708,76 @@ impl LogicalPlan {
}))
}
LogicalPlan::Union(Union { inputs, schema }) => {
let first_input_schema = inputs[0].schema();
if schema.fields().len() == first_input_schema.fields().len() {
// If inputs are not pruned do not change schema
Ok(LogicalPlan::Union(Union { inputs, schema }))
} else {
// A note on `Union`s constructed via `try_new_by_name`:
//
// At this point, the schema for each input should have
// the same width. Thus, we do not need to save whether a
// `Union` was created `BY NAME`, and can safely rely on the
// `try_new` initializer to derive the new schema based on
// column positions.
Ok(LogicalPlan::Union(Union::try_new(inputs)?))
// Fast path: if all inputs structurally match the cached schema
// (field count, types, names, qualifiers, nullability) then no
// recomputation is needed and we avoid any allocation.
let schemas_match = inputs.iter().all(|input| {
let input_schema = input.schema();
schema.fields().len() == input_schema.fields().len()
&& schema.iter().zip(input_schema.iter()).all(
|((q1, f1), (q2, f2))| {
q1 == q2
&& f1.name() == f2.name()
&& f1.data_type() == f2.data_type()
&& f1.is_nullable() == f2.is_nullable()
},
)
});
if schemas_match {
// Inputs are structurally identical to the cached schema.
return Ok(LogicalPlan::Union(Union { inputs, schema }));
}

// Slow path: inputs changed — recompute the schema.
//
// NOTE: A note on `Union`s constructed via `try_new_by_name`:
// At this point, the schema for each input should have
// the same width. Thus, we do not need to save whether a
// `Union` was created `BY NAME`, and can safely rely on the
// `try_new` initializer to derive the new schema based on
// column positions.
let mut recomputed = Union::try_new(inputs)?;

// Metadata preservation: Union::try_new uses intersection logic
// for metadata, but we want "later takes precedence" (extend semantics)
// to match coerce_union_schema_with_schema in type_coercion.rs.
let mut merged_metadata =
recomputed.inputs[0].schema().metadata().clone();
for input in recomputed.inputs.iter().skip(1) {
merged_metadata.extend(input.schema().metadata().clone());
}

let mut merged_field_metadata = recomputed.inputs[0]
.schema()
.fields()
.iter()
.map(|f| f.metadata().clone())
.collect::<Vec<_>>();

for input in recomputed.inputs.iter().skip(1) {
for (field_meta, input_field) in merged_field_metadata
.iter_mut()
.zip(input.schema().fields())
{
field_meta.extend(input_field.metadata().clone());
}
}

let new_fields = recomputed
.schema
.iter()
.zip(merged_field_metadata)
.map(|((qualifier, field), meta)| {
let mut field = field.as_ref().clone();
field.set_metadata(meta);
(qualifier.cloned(), Arc::new(field))
})
.collect::<Vec<_>>();

recomputed.schema =
Arc::new(DFSchema::new_with_metadata(new_fields, merged_metadata)?);

Ok(LogicalPlan::Union(recomputed))
}
LogicalPlan::Distinct(distinct) => {
let distinct = match distinct {
Expand Down Expand Up @@ -6128,4 +6184,179 @@ mod tests {

Ok(())
}

// Regression test: a `Union` whose cached schema has the same width as its
// inputs but a different column type must get a fresh schema from
// `recompute_schema()`.
#[test]
fn test_recompute_schema_union_type_mismatch() -> Result<()> {
    use arrow::datatypes::{DataType, Field, Schema};

    // Builds an unprojected scan of `table` over `schema`, wrapped so it can
    // be used directly as a `Union` input.
    let scan = |table: &str, schema: &Schema| -> Result<Arc<LogicalPlan>> {
        let plan = table_scan(Some(table), schema, None)?.build()?;
        Ok(Arc::new(plan))
    };

    let narrow = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let wide = Schema::new(vec![Field::new("a", DataType::Int64, false)]);

    // Baseline: both inputs are Int32, so the derived schema is Int32 too.
    let baseline = Union::try_new(vec![scan("t1", &narrow)?, scan("t2", &narrow)?])?;
    assert_eq!(
        baseline.schema.field(0).data_type(),
        &DataType::Int32,
        "sanity: starting schema is Int32"
    );

    // Pair Int64 inputs with the Int32 schema cached above. This mimics a
    // rewrite that swapped the inputs without touching the Union's schema.
    let outdated = LogicalPlan::Union(Union {
        inputs: vec![scan("t1", &wide)?, scan("t2", &wide)?],
        schema: Arc::clone(&baseline.schema),
    });

    let refreshed = outdated.recompute_schema()?;
    let refreshed_type = refreshed.schema().field(0).data_type();

    assert_eq!(
        refreshed_type,
        &DataType::Int64,
        "Union schema should track the new Int64 input types after \
         recompute_schema(), but the width-only check left it stale"
    );

    Ok(())
}

// Regression test: a `Union` whose cached schema differs from its inputs only
// by column name must pick up the inputs' name after `recompute_schema()`.
#[test]
fn test_recompute_schema_union_name_mismatch() -> Result<()> {
    use arrow::datatypes::{DataType, Field, Schema};

    // A one-column, non-nullable Int32 schema whose column is called `column`.
    let single_int32 =
        |column: &str| Schema::new(vec![Field::new(column, DataType::Int32, false)]);

    // Builds an unprojected scan of `table` over `schema`, wrapped so it can
    // be used directly as a `Union` input.
    let scan = |table: &str, schema: &Schema| -> Result<Arc<LogicalPlan>> {
        let plan = table_scan(Some(table), schema, None)?.build()?;
        Ok(Arc::new(plan))
    };

    let named_a = single_int32("a");
    let named_b = single_int32("b");

    // Baseline: both inputs expose column "a".
    let baseline = Union::try_new(vec![scan("t1", &named_a)?, scan("t2", &named_a)?])?;
    assert_eq!(
        baseline.schema.field(0).name(),
        "a",
        "sanity: starting schema has column name 'a'"
    );

    // Inputs now expose "b" while the Union still carries the schema cached
    // above (same width, same type, different name).
    let outdated = LogicalPlan::Union(Union {
        inputs: vec![scan("t1", &named_b)?, scan("t2", &named_b)?],
        schema: Arc::clone(&baseline.schema),
    });

    let refreshed = outdated.recompute_schema()?;
    let refreshed_name = refreshed.schema().field(0).name();

    assert_eq!(
        refreshed_name,
        "b",
        "Union schema should reflect the renamed column after \
         recompute_schema(), but the width-only check left it stale"
    );

    Ok(())
}

// Regression test: a `Union` whose cached schema says NOT NULL while its
// inputs are nullable must become nullable after `recompute_schema()`.
#[test]
fn test_recompute_schema_union_nullability_mismatch() -> Result<()> {
    use arrow::datatypes::{DataType, Field, Schema};

    // A one-column Int32 schema named "a" with the requested nullability.
    let int32_column =
        |nullable: bool| Schema::new(vec![Field::new("a", DataType::Int32, nullable)]);

    // Builds an unprojected scan of `table` over `schema`, wrapped so it can
    // be used directly as a `Union` input.
    let scan = |table: &str, schema: &Schema| -> Result<Arc<LogicalPlan>> {
        let plan = table_scan(Some(table), schema, None)?.build()?;
        Ok(Arc::new(plan))
    };

    let required = int32_column(false);
    let optional = int32_column(true);

    // Baseline: both inputs are NOT NULL.
    let baseline = Union::try_new(vec![scan("t1", &required)?, scan("t2", &required)?])?;
    assert!(
        !baseline.schema.field(0).is_nullable(),
        "sanity: starting schema field is NOT NULL"
    );

    // Nullable inputs paired with the NOT NULL schema cached above.
    let outdated = LogicalPlan::Union(Union {
        inputs: vec![scan("t1", &optional)?, scan("t2", &optional)?],
        schema: Arc::clone(&baseline.schema),
    });

    let refreshed = outdated.recompute_schema()?;
    let now_nullable = refreshed.schema().field(0).is_nullable();

    assert!(
        now_nullable,
        "Union schema should reflect the new nullable inputs after \
         recompute_schema(), but the stale NOT NULL schema was kept"
    );

    Ok(())
}

// Checks how schema-level metadata is combined when `recompute_schema()` has
// to rebuild a `Union` schema: for a key present in both inputs the value from
// the later input is expected, and keys present in only one input are kept.
#[test]
fn test_recompute_schema_union_metadata_preservation() -> Result<()> {
    use arrow::datatypes::{DataType, Field, Schema};
    use std::collections::HashMap;

    // Builds an unprojected scan of `table` over `schema`, wrapped so it can
    // be used directly as a `Union` input.
    let scan = |table: &str, schema: &Schema| -> Result<Arc<LogicalPlan>> {
        let plan = table_scan(Some(table), schema, None)?.build()?;
        Ok(Arc::new(plan))
    };

    // Both inputs share the same single column; only the metadata differs.
    let column_a = || vec![Field::new("a", DataType::Int32, false)];

    let first_meta = HashMap::from([("k1".to_string(), "v1".to_string())]);
    let second_meta = HashMap::from([
        // Same key as in `first_meta`, but with a conflicting value.
        ("k1".to_string(), "v2".to_string()),
        ("k2".to_string(), "v2".to_string()),
    ]);

    let schema1 = Schema::new_with_metadata(column_a(), first_meta);
    let schema2 = Schema::new_with_metadata(column_a(), second_meta);

    // For these inputs the schema derived by `Union::try_new` is expected to
    // carry no metadata at all: "k1" conflicts and "k2" exists on one side only.
    let baseline = Union::try_new(vec![scan("t1", &schema1)?, scan("t2", &schema2)?])?;
    assert!(baseline.schema.metadata().is_empty());

    // Pair freshly built inputs with a cached schema whose column name does
    // not match them, so `recompute_schema()` cannot keep the cached schema.
    let mismatched_schema = DFSchema::try_from(Schema::new(vec![Field::new(
        "wrong_name",
        DataType::Int32,
        false,
    )]))?;
    let outdated = LogicalPlan::Union(Union {
        inputs: vec![scan("t1", &schema1)?, scan("t2", &schema2)?],
        schema: Arc::new(mismatched_schema),
    });

    let refreshed = outdated.recompute_schema()?;
    let merged = refreshed.schema().metadata();

    // Expected result is {k1: v2, k2: v2}: the second (last) input wins on "k1".
    for key in ["k1", "k2"] {
        assert_eq!(merged.get(key).unwrap(), "v2");
    }

    Ok(())
}
}
Loading