diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index 2493ac1e5ad57..6ca9ced2fcb06 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -347,12 +347,98 @@ fn from_substrait_struct_type( ) -> datafusion::common::Result { let mut fields = vec![]; for (i, f) in s.types.iter().enumerate() { - let field = Field::new( - next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(consumer, f, dfs_names, name_idx)?, - true, // We assume everything to be nullable since that's easier than ensuring it matches - ); + let name = next_struct_field_name(i, dfs_names, name_idx)?; + let data_type = from_substrait_type(consumer, f, dfs_names, name_idx)?; + let field = Field::new(name, data_type, type_is_nullable(f)?); fields.push(field); } Ok(fields.into()) } + +fn type_is_nullable(dt: &Type) -> datafusion::common::Result { + let Some(kind) = dt.kind.as_ref() else { + return Ok(true); + }; + + let nullability = match kind { + r#type::Kind::Bool(boolean) => boolean.nullability, + r#type::Kind::I8(integer) => integer.nullability, + r#type::Kind::I16(integer) => integer.nullability, + r#type::Kind::I32(integer) => integer.nullability, + r#type::Kind::I64(integer) => integer.nullability, + r#type::Kind::Fp32(float) => float.nullability, + r#type::Kind::Fp64(float) => float.nullability, + #[expect(deprecated)] + r#type::Kind::Timestamp(timestamp) => timestamp.nullability, + r#type::Kind::Date(date) => date.nullability, + #[expect(deprecated)] + r#type::Kind::Time(time) => time.nullability, + #[expect(deprecated)] + r#type::Kind::TimestampTz(timestamp) => timestamp.nullability, + r#type::Kind::IntervalYear(interval) => interval.nullability, + r#type::Kind::IntervalDay(interval) => interval.nullability, + r#type::Kind::IntervalCompound(interval) => interval.nullability, + r#type::Kind::Uuid(uuid) => uuid.nullability, + r#type::Kind::String(string) => string.nullability, + r#type::Kind::Binary(binary) => binary.nullability, + r#type::Kind::FixedChar(fixed) => fixed.nullability, + r#type::Kind::Varchar(varchar) => varchar.nullability, + r#type::Kind::FixedBinary(fixed) => fixed.nullability, + r#type::Kind::Decimal(decimal) => decimal.nullability, + r#type::Kind::PrecisionTime(time) => time.nullability, + r#type::Kind::PrecisionTimestamp(timestamp) => timestamp.nullability, + r#type::Kind::PrecisionTimestampTz(timestamp) => timestamp.nullability, + r#type::Kind::Struct(r#struct) => r#struct.nullability, + r#type::Kind::List(list) => list.nullability, + r#type::Kind::Map(map) => map.nullability, + r#type::Kind::Func(func) => func.nullability, + r#type::Kind::UserDefined(user_defined) => user_defined.nullability, + #[expect(deprecated)] + r#type::Kind::UserDefinedTypeReference(_) => r#type::Nullability::Required as i32, + r#type::Kind::Alias(alias) => alias.nullability, + }; + + is_nullable(nullability) +} + +fn is_nullable(nullability: i32) -> datafusion::common::Result { + match r#type::Nullability::try_from(nullability) { + Ok(r#type::Nullability::Required) => Ok(false), + Ok(r#type::Nullability::Nullable | r#type::Nullability::Unspecified) => Ok(true), + Err(_) => not_impl_err!("Unsupported Substrait Nullability value {nullability}"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use substrait::proto::r#type::Kind; + + #[test] + fn type_is_nullable_user_defined_type_reference_is_required() { + // The deprecated `UserDefinedTypeReference` variant doesn't carry a + // nullability field; the consumer hardcodes Required (non-null). + #[expect(deprecated)] + let dt = Type { + kind: Some(Kind::UserDefinedTypeReference(0)), + }; + assert!(!type_is_nullable(&dt).unwrap()); + } + + #[test] + fn type_is_nullable_missing_kind_defaults_to_nullable() { + // Defensive: a Type whose kind is None is treated as nullable. + let dt = Type { kind: None }; + assert!(type_is_nullable(&dt).unwrap()); + } + + #[test] + fn is_nullable_rejects_unrecognized_enum_value() { + let err = is_nullable(i32::MAX).unwrap_err(); + assert!( + err.to_string() + .contains("Unsupported Substrait Nullability"), + "got: {err}" + ); + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs index 59cdf4a8fc93f..c654cc070938d 100644 --- a/datafusion/substrait/src/logical_plan/consumer/utils.rs +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -316,6 +316,10 @@ pub(super) fn ensure_schema_compatibility( /// 1. They have logically equivalent types. /// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields /// is not nullable. +/// 3. For Struct fields, every child field's nullability is compatible by the same rule +/// (recursively). +/// +/// TODO: Check nullability for List and Map fields. /// /// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not /// nullable. As such if DataFusion has that field as nullable the plan should be rejected. @@ -339,15 +343,56 @@ fn ensure_field_compatibility( datafusion_field.is_nullable(), substrait_field.is_nullable(), ) { - // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. return substrait_err!( "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", substrait_field.name() ); } + + ensure_nested_nullability_compatibility( + datafusion_field.data_type(), + substrait_field.data_type(), + substrait_field.name(), + ) +} + +/// Recurses through nested Struct DataTypes, applying +/// [`compatible_nullabilities`] to each child field. +/// +/// TODO: Add support for List/LargeList/FixedSizeList and Map fields. +fn ensure_nested_nullability_compatibility( + datafusion_type: &DataType, + substrait_type: &DataType, + field_path: &str, +) -> datafusion::common::Result<()> { + if let (DataType::Struct(df_fields), DataType::Struct(sub_fields)) = + (datafusion_type, substrait_type) + { + for (df_f, sub_f) in df_fields.iter().zip(sub_fields.iter()) { + check_nested_field(df_f, sub_f, field_path)?; + } + } Ok(()) } +fn check_nested_field( + df_field: &Field, + sub_field: &Field, + parent_path: &str, +) -> datafusion::common::Result<()> { + let path = format!("{parent_path}.{}", sub_field.name()); + if !compatible_nullabilities(df_field.is_nullable(), sub_field.is_nullable()) { + return substrait_err!( + "Field '{path}' is nullable in the DataFusion schema but not nullable in the Substrait schema." + ); + } + ensure_nested_nullability_compatibility( + df_field.data_type(), + sub_field.data_type(), + &path, + ) +} + /// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise fn compatible_nullabilities( datafusion_nullability: bool, @@ -521,10 +566,10 @@ pub(crate) fn from_substrait_precision( #[cfg(test)] pub(crate) mod tests { - use super::{NameTracker, make_renamed_schema}; + use super::{NameTracker, ensure_schema_compatibility, make_renamed_schema}; use crate::extensions::Extensions; use crate::logical_plan::consumer::DefaultSubstraitConsumer; - use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema}; use datafusion::common::DFSchema; use datafusion::error::Result; use datafusion::execution::SessionState; @@ -813,4 +858,55 @@ pub(crate) mod tests { Ok(()) } + + fn schema_with_struct_inner(inner_nullable: bool) -> DFSchema { + let inner = Field::new("inner", DataType::Int32, inner_nullable); + let outer = Field::new("s", DataType::Struct(Fields::from(vec![inner])), false); + DFSchema::try_from(Schema::new(vec![outer])).unwrap() + } + + #[test] + fn nested_compatibility_accepts_required_df_field() -> Result<()> { + // DF makes a stronger guarantee (required) than Substrait expects + // (nullable). The stronger guarantee is compatible with the weaker + // expectation, so this is accepted. + let df = schema_with_struct_inner(false); + let sub = schema_with_struct_inner(true); + ensure_schema_compatibility(&df, sub) + } + + #[test] + fn nested_compatibility_rejects_nullable_df_field() { + // Substrait says inner is required; DF says inner is nullable. The + // Substrait plan may rely on inner being non-null, so reject. + let df = schema_with_struct_inner(true); + let sub = schema_with_struct_inner(false); + let err = ensure_schema_compatibility(&df, sub).unwrap_err(); + assert!( + err.to_string().contains("'s.inner'"), + "expected error to identify the nested field path 's.inner', got: {err}" + ); + } + + #[test] + fn nested_compatibility_recurses_into_nested_struct() { + // Two levels of nesting: outer struct with required field that is + // itself a struct, whose `inner` field is required in Substrait but + // nullable in DF. + fn schema(inner_nullable: bool) -> DFSchema { + let inner = Field::new("inner", DataType::Int32, inner_nullable); + let middle = + Field::new("m", DataType::Struct(Fields::from(vec![inner])), false); + let outer = + Field::new("s", DataType::Struct(Fields::from(vec![middle])), false); + DFSchema::try_from(Schema::new(vec![outer])).unwrap() + } + let df = schema(true); + let sub = schema(false); + let err = ensure_schema_compatibility(&df, sub).unwrap_err(); + assert!( + err.to_string().contains("'s.m.inner'"), + "expected error to identify the deeply nested field path 's.m.inner', got: {err}" + ); + } } diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 53cb2eebfbda7..ebc7de5a291dd 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -534,4 +534,35 @@ mod tests { assert_eq!(schema.as_ref(), &roundtrip_schema); Ok(()) } + + #[test] + fn named_struct_unspecified_nullability_is_nullable() -> Result<()> { + let named_struct = NamedStruct { + names: vec!["unspecified".to_string(), "required".to_string()], + r#struct: Some(r#type::Struct { + types: vec![ + substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability: r#type::Nullability::Unspecified as i32, + })), + }, + substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability: r#type::Nullability::Required as i32, + })), + }, + ], + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability: r#type::Nullability::Required as i32, + }), + }; + + let schema = from_substrait_named_struct(&test_consumer(), &named_struct)?; + + assert!(schema.field(0).is_nullable()); + assert!(!schema.field(1).is_nullable()); + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1b8496c3dc729..1d65256d76420 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -1566,6 +1566,40 @@ async fn roundtrip_values_duplicate_column_join() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_preserves_field_nullability() -> Result<()> { + use datafusion::arrow::datatypes::Fields; + + // Verify that required and nullable fields, including nested struct fields, + // preserve their nullability through a Substrait round-trip. + // + // List child nullability is intentionally omitted because it is not + // preserved today. + let ctx = create_context().await?; + let df_schema = DFSchema::try_from(Schema::new(vec![ + Field::new("required_int", DataType::Int32, false), + Field::new("nullable_int", DataType::Int32, true), + Field::new( + "required_struct", + DataType::Struct(Fields::from(vec![ + Field::new("required_inner", DataType::Boolean, false), + Field::new("nullable_inner", DataType::Utf8, true), + ])), + false, + ), + ]))?; + let plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(df_schema), + }); + + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; + + assert_eq!(plan.schema(), plan2.schema()); + Ok(()) +} + #[tokio::test] async fn duplicate_column() -> Result<()> { // Substrait does not keep column names (aliases) in the plan, rather it operates on column indices diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index 9841c736da8c9..081cc01a5edcf 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -62,7 +62,7 @@ mod tests { read_json("tests/testdata/test_plans/simple_select.substrait.json"); // this is the exact schema of the Substrait plan let df_schema = - vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; + vec![("a", DataType::Int32, true), ("b", DataType::Int32, false)]; let ctx = generate_context_with_table("DATA", df_schema)?; let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; @@ -83,8 +83,8 @@ mod tests { read_json("tests/testdata/test_plans/simple_select.substrait.json"); // the DataFusion schema { b, a, c } contains the Substrait schema { a, b } let df_schema = vec![ - ("b", DataType::Int32, true), - ("a", DataType::Int32, false), + ("b", DataType::Int32, false), + ("a", DataType::Int32, true), ("c", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; @@ -150,5 +150,22 @@ mod tests { assert!(res.is_err()); Ok(()) } + + #[tokio::test] + async fn reject_plans_with_incompatible_field_nullability() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + let df_schema = + vec![("a", DataType::Int32, true), ("b", DataType::Int32, true)]; + + let ctx = generate_context_with_table("DATA", df_schema)?; + let res = from_substrait_plan(&ctx.state(), &proto_plan).await; + + assert_snapshot!( + res.unwrap_err().strip_backtrace(), + @r#"Substrait error: Field 'b' is nullable in the DataFusion schema but not nullable in the Substrait schema."# + ); + Ok(()) + } } }