diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index d86db1cbd00..0d5341d3e89 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -2423,7 +2423,32 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { "pointercast called on non-pointer dest type: {other:?}" )), }; - let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self); + + let dst_pointee_ty = self.lookup_type(dest_pointee); + let dest_pointee_size = dst_pointee_ty.sizeof(self); + let src_pointee_ty = self.lookup_type(ptr_pointee); + + // *[T; N] -> *RuntimeArray + if let SpirvType::Array { + element: elem_ty, .. + } = src_pointee_ty + && let SpirvType::RuntimeArray { + element: rt_elem_ty, + } = dst_pointee_ty + && elem_ty == rt_elem_ty + { + let zero = self.constant_u32(self.span(), 0).def(self); + let elem_ptr_ty = self.type_ptr_to(elem_ty); + let elem_ptr = self + .emit() + .in_bounds_access_chain(elem_ptr_ty, None, ptr.def(self), [zero]) + .unwrap(); + return self + .emit() + .bitcast(dest_ty, None, elem_ptr) + .unwrap() + .with_type(dest_ty); + } if let Some((indices, _)) = self.recover_access_chain_from_offset( ptr_pointee, diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 98e076bf5fe..4aba5b80858 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -660,6 +660,7 @@ pub fn link( peephole_opts::composite_construct(&types, func); peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func); peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func); + peephole_opts::fold_array_bitcast_access_chain(&types, func); } } diff --git a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs index 6162d55a377..e5037d7a7c3 100644 --- a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs +++ b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs @@ -680,3 +680,105 @@ pub fn fold_load_from_constant_variable(module: &mut Module) { } } } + +/// Eliminate the `OpBitcast` that arises from `*[T; N] → *RuntimeArray` pointer casts. +/// +/// When a local array is coerced to a slice (`&[T;N] as &[T]`), codegen emits: +/// ```text +/// %elem0 = OpInBoundsAccessChain %arr 0 ; pointer to arr[0] +/// %rta = OpBitcast %elem0 ; *RuntimeArray (invalid in Logical SPIR-V) +/// ``` +/// After inlining the slice-taking function, indexing becomes: +/// ```text +/// %ei = OpInBoundsAccessChain %rta i ; data[i] +/// +pub fn fold_array_bitcast_access_chain( + types: &FxHashMap, + function: &mut Function, +) { + let func_defs: FxHashMap = function + .all_inst_iter() + .filter_map(|inst| Some((inst.result_id?, inst.clone()))) + .collect(); + + // look up an ID in either function-local defs or module-level types/globals. + let lookup = + |id: Word| -> Option<&Instruction> { func_defs.get(&id).or_else(|| types.get(&id)) }; + + for block in &mut function.blocks { + for inst in &mut block.instructions { + if !matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain) { + continue; + } + if inst.operands.is_empty() { + continue; + } + let base_id = inst.operands[0].unwrap_id_ref(); + + // base must be an OpBitcast + let Some(bitcast) = lookup(base_id) else { + continue; + }; + if bitcast.class.opcode != Op::Bitcast { + continue; + } + + // bitcast result type must be *SC RuntimeArray + let Some(bitcast_dst_ptr) = lookup(bitcast.result_type.unwrap()) else { + continue; + }; + if bitcast_dst_ptr.class.opcode != Op::TypePointer { + continue; + } + let rta_type_id = bitcast_dst_ptr.operands[1].unwrap_id_ref(); + let Some(rta_ty) = lookup(rta_type_id) else { + continue; + }; + if rta_ty.class.opcode != Op::TypeRuntimeArray { + continue; + } + let rta_elem_ty = rta_ty.operands[0].unwrap_id_ref(); + + // bitcast source must be OpInBoundsAccessChain(arr, 0) + let bitcast_src_id = bitcast.operands[0].unwrap_id_ref(); + let Some(inner_ac) = lookup(bitcast_src_id) else { + continue; + }; + if inner_ac.class.opcode != Op::InBoundsAccessChain { + continue; + } + // Exactly one index operand + if inner_ac.operands.len() != 2 { + continue; + } + // That index must be the constant 0 + let idx0_id = inner_ac.operands[1].unwrap_id_ref(); + let Some(idx0) = lookup(idx0_id) else { + continue; + }; + if idx0.class.opcode != Op::Constant { + continue; + } + if !matches!(idx0.operands[0], Operand::LiteralBit32(0)) { + continue; + } + + // inner AccessChain result type must be *SC T where T == rta_elem_ty + let Some(inner_dst_ptr) = lookup(inner_ac.result_type.unwrap()) else { + continue; + }; + if inner_dst_ptr.class.opcode != Op::TypePointer { + continue; + } + let elem_ty = inner_dst_ptr.operands[1].unwrap_id_ref(); + if elem_ty != rta_elem_ty { + continue; + } + + // AccessChain(Bitcast(InBoundsAccessChain(arr, 0)), i) + // Replace base with arr — the dead bitcast and intermediate AC are cleaned by DCE. + let arr_id = inner_ac.operands[0].unwrap_id_ref(); + inst.operands[0] = Operand::IdRef(arr_id); + } + } +} diff --git a/tests/compiletests/ui/lang/core/array-slice-cast.rs b/tests/compiletests/ui/lang/core/array-slice-cast.rs new file mode 100644 index 00000000000..8a7f039fb66 --- /dev/null +++ b/tests/compiletests/ui/lang/core/array-slice-cast.rs @@ -0,0 +1,33 @@ +// build-pass +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "%\d+ = OpString .*\n" -> "" +// normalize-stderr-test "^(; .*\n)*" -> "" +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-spv1.4 +// ignore-spv1.5 +// ignore-spv1.6 +// ignore-vulkan1.0 +// ignore-vulkan1.1 +use spirv_std::spirv; + +fn do_work(data: &[u32], slab: &mut [u32]) { + slab[0] = data[0]; + slab[1] = data[1]; + slab[2] = data[2]; +} + +#[spirv(compute(threads(64)))] +pub fn compute_shader( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &mut [u32], + #[spirv(global_invocation_id)] global_id: glam::UVec3, +) { + let data = [global_id.x, global_id.y, global_id.z]; + do_work(&data, slab); +} diff --git a/tests/compiletests/ui/lang/core/array-slice-cast.stderr b/tests/compiletests/ui/lang/core/array-slice-cast.stderr new file mode 100644 index 00000000000..9b7c7b3782a --- /dev/null +++ b/tests/compiletests/ui/lang/core/array-slice-cast.stderr @@ -0,0 +1,116 @@ +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "compute_shader" %2 %3 +OpExecutionMode %1 LocalSize 64 1 1 +OpName %2 "slab" +OpName %3 "global_id" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %3 BuiltIn GlobalInvocationId +%7 = OpTypeInt 32 0 +%5 = OpTypeRuntimeArray %7 +%6 = OpTypeStruct %5 +%8 = OpTypePointer StorageBuffer %6 +%9 = OpTypeVector %7 3 +%10 = OpTypePointer Input %9 +%11 = OpTypeVoid +%12 = OpTypeFunction %11 +%13 = OpConstant %7 3 +%14 = OpTypeArray %7 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer StorageBuffer %5 +%2 = OpVariable %8 StorageBuffer +%17 = OpConstant %7 0 +%3 = OpVariable %10 Input +%18 = OpTypePointer Function %7 +%19 = OpConstant %7 1 +%20 = OpConstant %7 2 +%21 = OpTypeBool +%22 = OpTypePointer StorageBuffer %7 +%1 = OpFunction %11 None %12 +%23 = OpLabel +%24 = OpVariable %15 Function +%25 = OpInBoundsAccessChain %16 %2 %17 +%26 = OpArrayLength %7 %2 0 +%27 = OpLoad %9 %3 +%28 = OpCompositeExtract %7 %27 0 +%29 = OpCompositeExtract %7 %27 1 +%30 = OpCompositeExtract %7 %27 2 +%31 = OpInBoundsAccessChain %18 %24 %17 +OpStore %31 %28 +%32 = OpInBoundsAccessChain %18 %24 %19 +OpStore %32 %29 +%33 = OpInBoundsAccessChain %18 %24 %20 +OpStore %33 %30 +%34 = OpULessThan %21 %17 %13 +OpNoLine +OpSelectionMerge %35 None +OpBranchConditional %34 %36 %37 +%36 = OpLabel +OpBranch %35 +%37 = OpLabel +OpReturn +%35 = OpLabel +%38 = OpInBoundsAccessChain %18 %24 %17 +%39 = OpLoad %7 %38 +%40 = OpULessThan %21 %17 %26 +OpNoLine +OpSelectionMerge %41 None +OpBranchConditional %40 %42 %43 +%42 = OpLabel +OpBranch %41 +%43 = OpLabel +OpReturn +%41 = OpLabel +%44 = OpInBoundsAccessChain %22 %25 %17 +OpStore %44 %39 +%45 = OpULessThan %21 %19 %13 +OpNoLine +OpSelectionMerge %46 None +OpBranchConditional %45 %47 %48 +%47 = OpLabel +OpBranch %46 +%48 = OpLabel +OpReturn +%46 = OpLabel +%49 = OpInBoundsAccessChain %18 %24 %19 +%50 = OpLoad %7 %49 +%51 = OpULessThan %21 %19 %26 +OpNoLine +OpSelectionMerge %52 None +OpBranchConditional %51 %53 %54 +%53 = OpLabel +OpBranch %52 +%54 = OpLabel +OpReturn +%52 = OpLabel +%55 = OpInBoundsAccessChain %22 %25 %19 +OpStore %55 %50 +%56 = OpULessThan %21 %20 %13 +OpNoLine +OpSelectionMerge %57 None +OpBranchConditional %56 %58 %59 +%58 = OpLabel +OpBranch %57 +%59 = OpLabel +OpReturn +%57 = OpLabel +%60 = OpInBoundsAccessChain %18 %24 %20 +%61 = OpLoad %7 %60 +%62 = OpULessThan %21 %20 %26 +OpNoLine +OpSelectionMerge %63 None +OpBranchConditional %62 %64 %65 +%64 = OpLabel +OpBranch %63 +%65 = OpLabel +OpReturn +%63 = OpLabel +%66 = OpInBoundsAccessChain %22 %25 %20 +OpStore %66 %61 +OpNoLine +OpReturn +OpFunctionEnd