Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,24 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll
}
}
TypeKind::BFloat => rust_ty == cx.type_i16(),
TypeKind::X86_AMX if cx.type_kind(rust_ty) == TypeKind::Vector => {
let element_ty = cx.element_type(rust_ty);
let element_count = cx.vector_length(rust_ty) as u64;

let element_size_bits = match cx.type_kind(element_ty) {
TypeKind::Half => 16,
TypeKind::Float => 32,
TypeKind::Double => 64,
TypeKind::FP128 => 128,
TypeKind::Integer => cx.int_width(element_ty),
TypeKind::Pointer => cx.int_width(cx.isize_ty),
_ => bug!(
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
),
};

element_size_bits * element_count == 8192
}
_ => false,
}
}
Expand Down Expand Up @@ -1084,6 +1102,12 @@ fn autocast<'ll>(
)
}
}
(TypeKind::Vector, TypeKind::X86_AMX) => {
bx.call_intrinsic("llvm.x86.cast.vector.to.tile", &[src_ty], &[val])
}
(TypeKind::X86_AMX, TypeKind::Vector) => {
bx.call_intrinsic("llvm.x86.cast.tile.to.vector", &[dest_ty], &[val])
}
_ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
}
}
Expand Down
13 changes: 11 additions & 2 deletions compiler/rustc_monomorphize/src/mono_checks/abi_check.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This module ensures that if a function's ABI requires a particular target feature,
//! that target feature is enabled both on the callee and all callers.
use rustc_abi::{BackendRepr, CanonAbi, RegKind, X86Call};
use rustc_abi::{BackendRepr, CanonAbi, ExternAbi, RegKind, X86Call};
use rustc_hir::{CRATE_HIR_ID, HirId};
use rustc_middle::mir::{self, Location, traversal};
use rustc_middle::ty::{self, Instance, InstanceKind, Ty, TyCtxt};
Expand Down Expand Up @@ -157,6 +157,12 @@ fn do_check_unsized_params<'tcx>(
/// - the signature requires target features that are not enabled
fn check_instance_abi<'tcx>(tcx: TyCtxt<'tcx>, instance: Instance<'tcx>) {
let typing_env = ty::TypingEnv::fully_monomorphized();
let ty = instance.ty(tcx, typing_env);
if ty.is_fn() && ty.fn_sig(tcx).abi() == ExternAbi::Unadjusted {
// We disable all checks for the unadjusted ABI to allow linking to arbitrary LLVM
// intrinsics
return;
}
let Ok(abi) = tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
else {
// An error will be reported during codegen if we cannot determine the ABI of this
Expand Down Expand Up @@ -191,9 +197,12 @@ fn check_call_site_abi<'tcx>(
caller: InstanceKind<'tcx>,
loc: impl Fn() -> (Span, HirId) + Copy,
) {
if callee.fn_sig(tcx).abi().is_rustic_abi() {
let extern_abi = callee.fn_sig(tcx).abi();
if extern_abi.is_rustic_abi() || extern_abi == ExternAbi::Unadjusted {
// We directly handle the soundness of Rust ABIs -- so let's skip the majority of
// call sites to avoid a perf regression.
// We disable all checks for the unadjusted ABI to allow linking to arbitrary LLVM
// intrinsics
return;
}
let typing_env = ty::TypingEnv::fully_monomorphized();
Expand Down
29 changes: 27 additions & 2 deletions tests/codegen-llvm/inject-autocast.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avx512dq,+avxneconvert,+amx-int8
//@ only-x86_64

#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd, repr_simd)]
#![crate_type = "lib"]

use std::simd::{f32x4, i16x8, i64x2};
Expand All @@ -10,6 +10,9 @@ use std::simd::{f32x4, i16x8, i64x2};
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>

#[repr(simd)]
pub struct Tile([i8; 1024]);

// CHECK-LABEL: @struct_autocast
#[no_mangle]
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
Expand Down Expand Up @@ -84,10 +87,32 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
foo(a)
}

// CHECK-LABEL: @amx_autocast
#[no_mangle]
pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile {
extern "unadjusted" {
#[link_name = "llvm.x86.tdpbuud.internal"]
fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile;
}

// CHECK: [[A:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
// CHECK: [[B:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
// CHECK: [[C:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
// CHECK: [[D:%[0-9]+]] = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx [[A]], x86_amx [[B]], x86_amx [[C]])
// CHECK: call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx [[D]])
foo(m, n, k, a, b, c)
}

// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)

// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)

// CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>)

// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)

// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)

// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>)

// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)
Loading