Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 89 additions & 5 deletions datafusion/substrait/src/logical_plan/consumer/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,96 @@ fn from_substrait_struct_type(
) -> datafusion::common::Result<Fields> {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
let field = Field::new(
next_struct_field_name(i, dfs_names, name_idx)?,
from_substrait_type(consumer, f, dfs_names, name_idx)?,
true, // We assume everything to be nullable since that's easier than ensuring it matches
);
let name = next_struct_field_name(i, dfs_names, name_idx)?;
let data_type = from_substrait_type(consumer, f, dfs_names, name_idx)?;
let field = Field::new(name, data_type, type_is_nullable(f));
fields.push(field);
}
Ok(fields.into())
}

fn type_is_nullable(dt: &Type) -> bool {
let Some(kind) = dt.kind.as_ref() else {
return true;
};

let nullability = match kind {
r#type::Kind::Bool(boolean) => boolean.nullability,
r#type::Kind::I8(integer) => integer.nullability,
r#type::Kind::I16(integer) => integer.nullability,
r#type::Kind::I32(integer) => integer.nullability,
r#type::Kind::I64(integer) => integer.nullability,
r#type::Kind::Fp32(float) => float.nullability,
r#type::Kind::Fp64(float) => float.nullability,
#[expect(deprecated)]
r#type::Kind::Timestamp(timestamp) => timestamp.nullability,
r#type::Kind::Date(date) => date.nullability,
#[expect(deprecated)]
r#type::Kind::Time(time) => time.nullability,
#[expect(deprecated)]
r#type::Kind::TimestampTz(timestamp) => timestamp.nullability,
r#type::Kind::IntervalYear(interval) => interval.nullability,
r#type::Kind::IntervalDay(interval) => interval.nullability,
r#type::Kind::IntervalCompound(interval) => interval.nullability,
r#type::Kind::Uuid(uuid) => uuid.nullability,
r#type::Kind::String(string) => string.nullability,
r#type::Kind::Binary(binary) => binary.nullability,
r#type::Kind::FixedChar(fixed) => fixed.nullability,
r#type::Kind::Varchar(varchar) => varchar.nullability,
r#type::Kind::FixedBinary(fixed) => fixed.nullability,
r#type::Kind::Decimal(decimal) => decimal.nullability,
r#type::Kind::PrecisionTime(time) => time.nullability,
r#type::Kind::PrecisionTimestamp(timestamp) => timestamp.nullability,
r#type::Kind::PrecisionTimestampTz(timestamp) => timestamp.nullability,
r#type::Kind::Struct(r#struct) => r#struct.nullability,
r#type::Kind::List(list) => list.nullability,
r#type::Kind::Map(map) => map.nullability,
r#type::Kind::Func(func) => func.nullability,
r#type::Kind::UserDefined(user_defined) => user_defined.nullability,
#[expect(deprecated)]
r#type::Kind::UserDefinedTypeReference(_) => r#type::Nullability::Required as i32,
r#type::Kind::Alias(alias) => alias.nullability,
};

is_nullable(nullability)
}

fn is_nullable(nullability: i32) -> bool {
match r#type::Nullability::try_from(nullability) {
Ok(r#type::Nullability::Required) => false,
Ok(r#type::Nullability::Nullable | r#type::Nullability::Unspecified) | Err(_) => {
true
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use substrait::proto::r#type::Kind;

#[test]
fn type_is_nullable_user_defined_type_reference_is_required() {
// The deprecated `UserDefinedTypeReference` variant doesn't carry a
// nullability field; the consumer hardcodes Required (non-null).
#[expect(deprecated)]
let dt = Type {
kind: Some(Kind::UserDefinedTypeReference(0)),
};
assert!(!type_is_nullable(&dt));
}

#[test]
fn type_is_nullable_missing_kind_defaults_to_nullable() {
// Defensive: a Type whose kind is None is treated as nullable.
let dt = Type { kind: None };
assert!(type_is_nullable(&dt));
}

#[test]
fn is_nullable_handles_unrecognized_enum_value() {
// Defensive: an unrecognized Nullability enum value (one prost
// doesn't know about) is treated as nullable.
assert!(is_nullable(i32::MAX));
}
}
102 changes: 99 additions & 3 deletions datafusion/substrait/src/logical_plan/consumer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ pub(super) fn ensure_schema_compatibility(
/// 1. They have logically equivalent types.
/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields
/// is not nullable.
/// 3. For Struct fields, every child field's nullability is compatible by the same rule
/// (recursively).
///
/// TODO: Check nullability for List and Map fields.
///
/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not
/// nullable. As such if DataFusion has that field as nullable the plan should be rejected.
Expand All @@ -339,15 +343,56 @@ fn ensure_field_compatibility(
datafusion_field.is_nullable(),
substrait_field.is_nullable(),
) {
// TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now.
return substrait_err!(
"Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.",
substrait_field.name()
);
}

ensure_nested_nullability_compatibility(
datafusion_field.data_type(),
substrait_field.data_type(),
substrait_field.name(),
)
}

/// Recurses through nested Struct DataTypes, applying
/// [`compatible_nullabilities`] to each child field.
///
/// TODO: Add support for List/LargeList/FixedSizeList and Map fields.
fn ensure_nested_nullability_compatibility(
datafusion_type: &DataType,
substrait_type: &DataType,
field_path: &str,
) -> datafusion::common::Result<()> {
if let (DataType::Struct(df_fields), DataType::Struct(sub_fields)) =
(datafusion_type, substrait_type)
{
for (df_f, sub_f) in df_fields.iter().zip(sub_fields.iter()) {
check_nested_field(df_f, sub_f, field_path)?;
}
}
Ok(())
}

fn check_nested_field(
df_field: &Field,
sub_field: &Field,
parent_path: &str,
) -> datafusion::common::Result<()> {
let path = format!("{parent_path}.{}", sub_field.name());
if !compatible_nullabilities(df_field.is_nullable(), sub_field.is_nullable()) {
return substrait_err!(
"Field '{path}' is nullable in the DataFusion schema but not nullable in the Substrait schema."
);
}
ensure_nested_nullability_compatibility(
df_field.data_type(),
sub_field.data_type(),
&path,
)
}

/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise
fn compatible_nullabilities(
datafusion_nullability: bool,
Expand Down Expand Up @@ -521,10 +566,10 @@ pub(crate) fn from_substrait_precision(

#[cfg(test)]
pub(crate) mod tests {
use super::{NameTracker, make_renamed_schema};
use super::{NameTracker, ensure_schema_compatibility, make_renamed_schema};
use crate::extensions::Extensions;
use crate::logical_plan::consumer::DefaultSubstraitConsumer;
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema};
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::execution::SessionState;
Expand Down Expand Up @@ -813,4 +858,55 @@ pub(crate) mod tests {

Ok(())
}

fn schema_with_struct_inner(inner_nullable: bool) -> DFSchema {
let inner = Field::new("inner", DataType::Int32, inner_nullable);
let outer = Field::new("s", DataType::Struct(Fields::from(vec![inner])), false);
DFSchema::try_from(Schema::new(vec![outer])).unwrap()
}

#[test]
fn nested_compatibility_accepts_required_df_field() -> Result<()> {
// DF makes a stronger guarantee (required) than Substrait expects
// (nullable). The stronger guarantee is compatible with the weaker
// expectation, so this is accepted.
let df = schema_with_struct_inner(false);
let sub = schema_with_struct_inner(true);
ensure_schema_compatibility(&df, sub)
}

#[test]
fn nested_compatibility_rejects_nullable_df_field() {
// Substrait says inner is required; DF says inner is nullable. The
// Substrait plan may rely on inner being non-null, so reject.
let df = schema_with_struct_inner(true);
let sub = schema_with_struct_inner(false);
let err = ensure_schema_compatibility(&df, sub).unwrap_err();
assert!(
err.to_string().contains("'s.inner'"),
"expected error to identify the nested field path 's.inner', got: {err}"
);
}

#[test]
fn nested_compatibility_recurses_into_nested_struct() {
// Two levels of nesting: outer struct with required field that is
// itself a struct, whose `inner` field is required in Substrait but
// nullable in DF.
fn schema(inner_nullable: bool) -> DFSchema {
let inner = Field::new("inner", DataType::Int32, inner_nullable);
let middle =
Field::new("m", DataType::Struct(Fields::from(vec![inner])), false);
let outer =
Field::new("s", DataType::Struct(Fields::from(vec![middle])), false);
DFSchema::try_from(Schema::new(vec![outer])).unwrap()
}
let df = schema(true);
let sub = schema(false);
let err = ensure_schema_compatibility(&df, sub).unwrap_err();
assert!(
err.to_string().contains("'s.m.inner'"),
"expected error to identify the deeply nested field path 's.m.inner', got: {err}"
);
}
}
31 changes: 31 additions & 0 deletions datafusion/substrait/src/logical_plan/producer/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,4 +534,35 @@ mod tests {
assert_eq!(schema.as_ref(), &roundtrip_schema);
Ok(())
}

#[test]
fn named_struct_unspecified_nullability_is_nullable() -> Result<()> {
let named_struct = NamedStruct {
names: vec!["unspecified".to_string(), "required".to_string()],
r#struct: Some(r#type::Struct {
types: vec![
substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability: r#type::Nullability::Unspecified as i32,
})),
},
substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability: r#type::Nullability::Required as i32,
})),
},
],
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability: r#type::Nullability::Required as i32,
}),
};

let schema = from_substrait_named_struct(&test_consumer(), &named_struct)?;

assert!(schema.field(0).is_nullable());
assert!(!schema.field(1).is_nullable());
Ok(())
}
}
34 changes: 34 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,40 @@ async fn roundtrip_values_duplicate_column_join() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_preserves_field_nullability() -> Result<()> {
use datafusion::arrow::datatypes::Fields;

// Verify that required and nullable fields, including nested struct fields,
// preserve their nullability through a Substrait round-trip.
//
// List child nullability is intentionally omitted because it is not
// preserved today.
let ctx = create_context().await?;
let df_schema = DFSchema::try_from(Schema::new(vec![
Field::new("required_int", DataType::Int32, false),
Field::new("nullable_int", DataType::Int32, true),
Field::new(
"required_struct",
DataType::Struct(Fields::from(vec![
Field::new("required_inner", DataType::Boolean, false),
Field::new("nullable_inner", DataType::Utf8, true),
])),
false,
),
]))?;
let plan = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: DFSchemaRef::new(df_schema),
});

let proto = to_substrait_plan(&plan, &ctx.state())?;
let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;

assert_eq!(plan.schema(), plan2.schema());
Ok(())
}

#[tokio::test]
async fn duplicate_column() -> Result<()> {
// Substrait does not keep column names (aliases) in the plan, rather it operates on column indices
Expand Down
23 changes: 20 additions & 3 deletions datafusion/substrait/tests/cases/substrait_validations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ mod tests {
read_json("tests/testdata/test_plans/simple_select.substrait.json");
// this is the exact schema of the Substrait plan
let df_schema =
vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)];
vec![("a", DataType::Int32, true), ("b", DataType::Int32, false)];

let ctx = generate_context_with_table("DATA", df_schema)?;
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
Expand All @@ -83,8 +83,8 @@ mod tests {
read_json("tests/testdata/test_plans/simple_select.substrait.json");
// the DataFusion schema { b, a, c } contains the Substrait schema { a, b }
let df_schema = vec![
("b", DataType::Int32, true),
("a", DataType::Int32, false),
("b", DataType::Int32, false),
("a", DataType::Int32, true),
("c", DataType::Int32, false),
];
let ctx = generate_context_with_table("DATA", df_schema)?;
Expand Down Expand Up @@ -150,5 +150,22 @@ mod tests {
assert!(res.is_err());
Ok(())
}

#[tokio::test]
async fn reject_plans_with_incompatible_field_nullability() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/simple_select.substrait.json");
let df_schema =
vec![("a", DataType::Int32, true), ("b", DataType::Int32, true)];

let ctx = generate_context_with_table("DATA", df_schema)?;
let res = from_substrait_plan(&ctx.state(), &proto_plan).await;

assert_snapshot!(
res.unwrap_err().strip_backtrace(),
@r#"Substrait error: Field 'b' is nullable in the DataFusion schema but not nullable in the Substrait schema."#
);
Ok(())
}
}
}
Loading