Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add const-eval support for SIMD types, insert, and extract #64738

Merged
merged 12 commits into from
Sep 25, 2019
45 changes: 45 additions & 0 deletions src/librustc_mir/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,52 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
"transmute" => {
self.copy_op_transmute(args[0], dest)?;
}
"simd_insert" => {
let index = self.read_scalar(args[1])?.to_u32()? as u64;
let scalar = args[2];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite confusing to call something of type OpTy a "scalar". Those should have type Scalar, or this should be called somewhere else, or there should at least be a comment.

let input = args[0];
let (len, e_ty) = self.read_vector_ty(input);
assert!(
index < len,
"Index `{}` must be in bounds of vector type `{}`: `[0, {})`",
index, e_ty, len
);
assert_eq!(
input.layout, dest.layout,
"Return type `{}` must match vector type `{}`",
dest.layout.ty, input.layout.ty
);
assert_eq!(
scalar.layout.ty, e_ty,
"Scalar type `{}` must match vector element type `{}`",
scalar.layout.ty, e_ty
);

for i in 0..len {
let place = self.place_field(dest, i)?;
let value = if i == index {
scalar
} else {
self.operand_field(input, i)?
};
self.copy_op(value, place)?;
}
}
"simd_extract" => {
let index = self.read_scalar(args[1])?.to_u32()? as _;
let (len, e_ty) = self.read_vector_ty(args[0]);
assert!(
index < len,
"index `{}` is out-of-bounds of vector type `{}` with length `{}`",
index, e_ty, len
);
assert_eq!(
e_ty, dest.layout.ty,
"Return type `{}` must match vector element type `{}`",
dest.layout.ty, e_ty
);
self.copy_op(self.operand_field(args[0], index)?, dest)?;
}
_ => return Ok(false),
}

Expand Down
11 changes: 11 additions & 0 deletions src/librustc_mir/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,17 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
}
}

/// Read vector length and element type
pub fn read_vector_ty(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't even look at the actual "operand", right? It is just a layout function?

Then it (a) should be mixed up with all the other read_* methods as it is very different, and (b) shouldn't even be called "read" as it doesn't read from memory.

I am not sure what the best place is for such a method, but it is not operand.rs. Sounds more like a TyLayout method to me?

&self, op: OpTy<'tcx, M::PointerTag>
) -> (u64, &rustc::ty::TyS<'tcx>) {
if let layout::Abi::Vector { .. } = op.layout.abi {
(op.layout.ty.simd_size(*self.tcx) as _, op.layout.ty.simd_type(*self.tcx))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why this is not using the element count from Abi::Vector?

} else {
bug!("Type `{}` is not a SIMD vector type", op.layout.ty)
}
}

/// Read a scalar from a place
pub fn read_scalar(
&self,
Expand Down
3 changes: 0 additions & 3 deletions src/librustc_mir/interpret/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,6 @@ impl<'mir, 'tcx, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {

match instance.def {
ty::InstanceDef::Intrinsic(..) => {
if caller_abi != Abi::RustIntrinsic {
throw_unsup!(FunctionAbiMismatch(caller_abi, Abi::RustIntrinsic))
}
Comment on lines -252 to -254
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh, why this? Instead of throwing out this check entirely, you could have just added platform intrinsics to the whitelist.

// The intrinsic itself cannot diverge, so if we got here without a return
// place... (can happen e.g., for transmute returning `!`)
let dest = match dest {
Expand Down
2 changes: 2 additions & 0 deletions src/librustc_mir/transform/qualify_consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ impl Qualif for IsNotPromotable {
| "saturating_add"
| "saturating_sub"
| "transmute"
| "simd_insert"
| "simd_extract"
=> return true,

_ => {}
Expand Down
53 changes: 53 additions & 0 deletions src/test/ui/consts/const-eval/simd/insert_extract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// run-pass
#![feature(const_fn)]
#![feature(repr_simd)]
#![feature(platform_intrinsics)]
#![allow(non_camel_case_types)]

#[repr(simd)] struct i8x1(i8);
#[repr(simd)] struct u16x2(u16, u16);
#[repr(simd)] struct f32x3(f32, f32, f32);

extern "platform-intrinsic" {
fn simd_insert<T, U>(x: T, idx: u32, val: U) -> T;
fn simd_extract<T, U>(x: T, idx: u32) -> U;
}

fn main() {
{
const U: i8x1 = i8x1(13);
const V: i8x1 = unsafe { simd_insert(U, 0_u32, 42_i8) };
const X0: i8 = V.0;
const Y0: i8 = unsafe { simd_extract(V, 0) };
assert_eq!(X0, 42);
assert_eq!(Y0, 42);
}
{
const U: u16x2 = u16x2(13, 14);
const V: u16x2 = unsafe { simd_insert(U, 1_u32, 42_u16) };
const X0: u16 = V.0;
const X1: u16 = V.1;
const Y0: u16 = unsafe { simd_extract(V, 0) };
const Y1: u16 = unsafe { simd_extract(V, 1) };
assert_eq!(X0, 13);
assert_eq!(X1, 42);
assert_eq!(Y0, 13);
assert_eq!(Y1, 42);
}
{
const U: f32x3 = f32x3(13., 14., 15.);
const V: f32x3 = unsafe { simd_insert(U, 1_u32, 42_f32) };
const X0: f32 = V.0;
const X1: f32 = V.1;
const X2: f32 = V.2;
const Y0: f32 = unsafe { simd_extract(V, 0) };
const Y1: f32 = unsafe { simd_extract(V, 1) };
const Y2: f32 = unsafe { simd_extract(V, 2) };
assert_eq!(X0, 13.);
assert_eq!(X1, 42.);
assert_eq!(X2, 15.);
assert_eq!(Y0, 13.);
assert_eq!(Y1, 42.);
assert_eq!(Y2, 15.);
}
}