From e46d09096431acd0370bfda411aef0b7de54c663 Mon Sep 17 00:00:00 2001 From: 39ali Date: Sat, 18 Apr 2026 13:18:27 +0300 Subject: [PATCH 1/2] fix array slice cast --- .../src/builder/builder_methods.rs | 23 +++- crates/rustc_codegen_spirv/src/linker/mod.rs | 1 + .../src/linker/peephole_opts.rs | 91 ++++++++++++++ .../ui/lang/core/array-slice-cast.rs | 33 +++++ .../ui/lang/core/array-slice-cast.stderr | 116 ++++++++++++++++++ 5 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 tests/compiletests/ui/lang/core/array-slice-cast.rs create mode 100644 tests/compiletests/ui/lang/core/array-slice-cast.stderr diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index d86db1cbd00..7eb272b4194 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -2423,7 +2423,28 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { "pointercast called on non-pointer dest type: {other:?}" )), }; - let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self); + + let dst_pointee_ty = self.lookup_type(dest_pointee); + let dest_pointee_size = dst_pointee_ty.sizeof(self); + let src_pointee_ty = self.lookup_type(ptr_pointee); + + // *[T; N] -> *RuntimeArray + if let SpirvType::Array { element: elem_ty, .. } = src_pointee_ty + && let SpirvType::RuntimeArray { element: rt_elem_ty } = dst_pointee_ty + && elem_ty == rt_elem_ty + { + let zero = self.constant_u32(self.span(), 0).def(self); + let elem_ptr_ty = self.type_ptr_to(elem_ty); + let elem_ptr = self + .emit() + .in_bounds_access_chain(elem_ptr_ty, None, ptr.def(self), [zero]) + .unwrap(); + return self + .emit() + .bitcast(dest_ty, None, elem_ptr) + .unwrap() + .with_type(dest_ty); + } if let Some((indices, _)) = self.recover_access_chain_from_offset( ptr_pointee, diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 98e076bf5fe..4aba5b80858 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -660,6 +660,7 @@ pub fn link( peephole_opts::composite_construct(&types, func); peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func); peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func); + peephole_opts::fold_array_bitcast_access_chain(&types, func); } } diff --git a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs index 6162d55a377..6a702ed08e6 100644 --- a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs +++ b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs @@ -680,3 +680,94 @@ pub fn fold_load_from_constant_variable(module: &mut Module) { } } } + +/// Eliminate the `OpBitcast` that arises from `*[T; N] → *RuntimeArray` pointer casts. +/// +/// When a local array is coerced to a slice (`&[T;N] as &[T]`), codegen emits: +/// ```text +/// %elem0 = OpInBoundsAccessChain %arr 0 ; pointer to arr[0] +/// %rta = OpBitcast %elem0 ; *RuntimeArray (invalid in Logical SPIR-V) +/// ``` +/// After inlining the slice-taking function, indexing becomes: +/// ```text +/// %ei = OpInBoundsAccessChain %rta i ; data[i] +/// +pub fn fold_array_bitcast_access_chain( + types: &FxHashMap, + function: &mut Function, +) { + let func_defs: FxHashMap = function + .all_inst_iter() + .filter_map(|inst| Some((inst.result_id?, inst.clone()))) + .collect(); + + // look up an ID in either function-local defs or module-level types/globals. + let lookup = |id: Word| -> Option<&Instruction> { + func_defs.get(&id).or_else(|| types.get(&id)) + }; + + for block in &mut function.blocks { + for inst in &mut block.instructions { + if !matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain) { + continue; + } + if inst.operands.is_empty() { + continue; + } + let base_id = inst.operands[0].unwrap_id_ref(); + + // base must be an OpBitcast + let Some(bitcast) = lookup(base_id) else { continue }; + if bitcast.class.opcode != Op::Bitcast { + continue; + } + + // bitcast result type must be *SC RuntimeArray + let Some(bitcast_dst_ptr) = lookup(bitcast.result_type.unwrap()) else { continue }; + if bitcast_dst_ptr.class.opcode != Op::TypePointer { + continue; + } + let rta_type_id = bitcast_dst_ptr.operands[1].unwrap_id_ref(); + let Some(rta_ty) = lookup(rta_type_id) else { continue }; + if rta_ty.class.opcode != Op::TypeRuntimeArray { + continue; + } + let rta_elem_ty = rta_ty.operands[0].unwrap_id_ref(); + + // bitcast source must be OpInBoundsAccessChain(arr, 0) + let bitcast_src_id = bitcast.operands[0].unwrap_id_ref(); + let Some(inner_ac) = lookup(bitcast_src_id) else { continue }; + if inner_ac.class.opcode != Op::InBoundsAccessChain { + continue; + } + // Exactly one index operand + if inner_ac.operands.len() != 2 { + continue; + } + // That index must be the constant 0 + let idx0_id = inner_ac.operands[1].unwrap_id_ref(); + let Some(idx0) = lookup(idx0_id) else { continue }; + if idx0.class.opcode != Op::Constant { + continue; + } + if !matches!(idx0.operands[0], Operand::LiteralBit32(0)) { + continue; + } + + // inner AccessChain result type must be *SC T where T == rta_elem_ty + let Some(inner_dst_ptr) = lookup(inner_ac.result_type.unwrap()) else { continue }; + if inner_dst_ptr.class.opcode != Op::TypePointer { + continue; + } + let elem_ty = inner_dst_ptr.operands[1].unwrap_id_ref(); + if elem_ty != rta_elem_ty { + continue; + } + + // AccessChain(Bitcast(InBoundsAccessChain(arr, 0)), i) + // Replace base with arr — the dead bitcast and intermediate AC are cleaned by DCE. + let arr_id = inner_ac.operands[0].unwrap_id_ref(); + inst.operands[0] = Operand::IdRef(arr_id); + } + } +} diff --git a/tests/compiletests/ui/lang/core/array-slice-cast.rs b/tests/compiletests/ui/lang/core/array-slice-cast.rs new file mode 100644 index 00000000000..8a7f039fb66 --- /dev/null +++ b/tests/compiletests/ui/lang/core/array-slice-cast.rs @@ -0,0 +1,33 @@ +// build-pass +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "%\d+ = OpString .*\n" -> "" +// normalize-stderr-test "^(; .*\n)*" -> "" +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-spv1.4 +// ignore-spv1.5 +// ignore-spv1.6 +// ignore-vulkan1.0 +// ignore-vulkan1.1 +use spirv_std::spirv; + +fn do_work(data: &[u32], slab: &mut [u32]) { + slab[0] = data[0]; + slab[1] = data[1]; + slab[2] = data[2]; +} + +#[spirv(compute(threads(64)))] +pub fn compute_shader( + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &mut [u32], + #[spirv(global_invocation_id)] global_id: glam::UVec3, +) { + let data = [global_id.x, global_id.y, global_id.z]; + do_work(&data, slab); +} diff --git a/tests/compiletests/ui/lang/core/array-slice-cast.stderr b/tests/compiletests/ui/lang/core/array-slice-cast.stderr new file mode 100644 index 00000000000..9b7c7b3782a --- /dev/null +++ b/tests/compiletests/ui/lang/core/array-slice-cast.stderr @@ -0,0 +1,116 @@ +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "compute_shader" %2 %3 +OpExecutionMode %1 LocalSize 64 1 1 +OpName %2 "slab" +OpName %3 "global_id" +OpDecorate %5 ArrayStride 4 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %3 BuiltIn GlobalInvocationId +%7 = OpTypeInt 32 0 +%5 = OpTypeRuntimeArray %7 +%6 = OpTypeStruct %5 +%8 = OpTypePointer StorageBuffer %6 +%9 = OpTypeVector %7 3 +%10 = OpTypePointer Input %9 +%11 = OpTypeVoid +%12 = OpTypeFunction %11 +%13 = OpConstant %7 3 +%14 = OpTypeArray %7 %13 +%15 = OpTypePointer Function %14 +%16 = OpTypePointer StorageBuffer %5 +%2 = OpVariable %8 StorageBuffer +%17 = OpConstant %7 0 +%3 = OpVariable %10 Input +%18 = OpTypePointer Function %7 +%19 = OpConstant %7 1 +%20 = OpConstant %7 2 +%21 = OpTypeBool +%22 = OpTypePointer StorageBuffer %7 +%1 = OpFunction %11 None %12 +%23 = OpLabel +%24 = OpVariable %15 Function +%25 = OpInBoundsAccessChain %16 %2 %17 +%26 = OpArrayLength %7 %2 0 +%27 = OpLoad %9 %3 +%28 = OpCompositeExtract %7 %27 0 +%29 = OpCompositeExtract %7 %27 1 +%30 = OpCompositeExtract %7 %27 2 +%31 = OpInBoundsAccessChain %18 %24 %17 +OpStore %31 %28 +%32 = OpInBoundsAccessChain %18 %24 %19 +OpStore %32 %29 +%33 = OpInBoundsAccessChain %18 %24 %20 +OpStore %33 %30 +%34 = OpULessThan %21 %17 %13 +OpNoLine +OpSelectionMerge %35 None +OpBranchConditional %34 %36 %37 +%36 = OpLabel +OpBranch %35 +%37 = OpLabel +OpReturn +%35 = OpLabel +%38 = OpInBoundsAccessChain %18 %24 %17 +%39 = OpLoad %7 %38 +%40 = OpULessThan %21 %17 %26 +OpNoLine +OpSelectionMerge %41 None +OpBranchConditional %40 %42 %43 +%42 = OpLabel +OpBranch %41 +%43 = OpLabel +OpReturn +%41 = OpLabel +%44 = OpInBoundsAccessChain %22 %25 %17 +OpStore %44 %39 +%45 = OpULessThan %21 %19 %13 +OpNoLine +OpSelectionMerge %46 None +OpBranchConditional %45 %47 %48 +%47 = OpLabel +OpBranch %46 +%48 = OpLabel +OpReturn +%46 = OpLabel +%49 = OpInBoundsAccessChain %18 %24 %19 +%50 = OpLoad %7 %49 +%51 = OpULessThan %21 %19 %26 +OpNoLine +OpSelectionMerge %52 None +OpBranchConditional %51 %53 %54 +%53 = OpLabel +OpBranch %52 +%54 = OpLabel +OpReturn +%52 = OpLabel +%55 = OpInBoundsAccessChain %22 %25 %19 +OpStore %55 %50 +%56 = OpULessThan %21 %20 %13 +OpNoLine +OpSelectionMerge %57 None +OpBranchConditional %56 %58 %59 +%58 = OpLabel +OpBranch %57 +%59 = OpLabel +OpReturn +%57 = OpLabel +%60 = OpInBoundsAccessChain %18 %24 %20 +%61 = OpLoad %7 %60 +%62 = OpULessThan %21 %20 %26 +OpNoLine +OpSelectionMerge %63 None +OpBranchConditional %62 %64 %65 +%64 = OpLabel +OpBranch %63 +%65 = OpLabel +OpReturn +%63 = OpLabel +%66 = OpInBoundsAccessChain %22 %25 %20 +OpStore %66 %61 +OpNoLine +OpReturn +OpFunctionEnd From 166d9dd9c8fb77518b41b2f99e5a5b521029dd69 Mon Sep 17 00:00:00 2001 From: 39ali Date: Sat, 18 Apr 2026 13:28:05 +0300 Subject: [PATCH 2/2] fmt --- .../src/builder/builder_methods.rs | 8 +++-- .../src/linker/peephole_opts.rs | 31 +++++++++++++------ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 7eb272b4194..0d5341d3e89 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -2429,8 +2429,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let src_pointee_ty = self.lookup_type(ptr_pointee); // *[T; N] -> *RuntimeArray - if let SpirvType::Array { element: elem_ty, .. } = src_pointee_ty - && let SpirvType::RuntimeArray { element: rt_elem_ty } = dst_pointee_ty + if let SpirvType::Array { + element: elem_ty, .. + } = src_pointee_ty + && let SpirvType::RuntimeArray { + element: rt_elem_ty, + } = dst_pointee_ty && elem_ty == rt_elem_ty { let zero = self.constant_u32(self.span(), 0).def(self); diff --git a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs index 6a702ed08e6..e5037d7a7c3 100644 --- a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs +++ b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs @@ -691,7 +691,7 @@ pub fn fold_load_from_constant_variable(module: &mut Module) { /// After inlining the slice-taking function, indexing becomes: /// ```text /// %ei = OpInBoundsAccessChain %rta i ; data[i] -/// +/// pub fn fold_array_bitcast_access_chain( types: &FxHashMap, function: &mut Function, @@ -702,9 +702,8 @@ pub fn fold_array_bitcast_access_chain( .collect(); // look up an ID in either function-local defs or module-level types/globals. - let lookup = |id: Word| -> Option<&Instruction> { - func_defs.get(&id).or_else(|| types.get(&id)) - }; + let lookup = + |id: Word| -> Option<&Instruction> { func_defs.get(&id).or_else(|| types.get(&id)) }; for block in &mut function.blocks { for inst in &mut block.instructions { @@ -717,18 +716,24 @@ pub fn fold_array_bitcast_access_chain( let base_id = inst.operands[0].unwrap_id_ref(); // base must be an OpBitcast - let Some(bitcast) = lookup(base_id) else { continue }; + let Some(bitcast) = lookup(base_id) else { + continue; + }; if bitcast.class.opcode != Op::Bitcast { continue; } // bitcast result type must be *SC RuntimeArray - let Some(bitcast_dst_ptr) = lookup(bitcast.result_type.unwrap()) else { continue }; + let Some(bitcast_dst_ptr) = lookup(bitcast.result_type.unwrap()) else { + continue; + }; if bitcast_dst_ptr.class.opcode != Op::TypePointer { continue; } let rta_type_id = bitcast_dst_ptr.operands[1].unwrap_id_ref(); - let Some(rta_ty) = lookup(rta_type_id) else { continue }; + let Some(rta_ty) = lookup(rta_type_id) else { + continue; + }; if rta_ty.class.opcode != Op::TypeRuntimeArray { continue; } @@ -736,7 +741,9 @@ pub fn fold_array_bitcast_access_chain( // bitcast source must be OpInBoundsAccessChain(arr, 0) let bitcast_src_id = bitcast.operands[0].unwrap_id_ref(); - let Some(inner_ac) = lookup(bitcast_src_id) else { continue }; + let Some(inner_ac) = lookup(bitcast_src_id) else { + continue; + }; if inner_ac.class.opcode != Op::InBoundsAccessChain { continue; } @@ -746,7 +753,9 @@ pub fn fold_array_bitcast_access_chain( } // That index must be the constant 0 let idx0_id = inner_ac.operands[1].unwrap_id_ref(); - let Some(idx0) = lookup(idx0_id) else { continue }; + let Some(idx0) = lookup(idx0_id) else { + continue; + }; if idx0.class.opcode != Op::Constant { continue; } @@ -755,7 +764,9 @@ pub fn fold_array_bitcast_access_chain( } // inner AccessChain result type must be *SC T where T == rta_elem_ty - let Some(inner_dst_ptr) = lookup(inner_ac.result_type.unwrap()) else { continue }; + let Some(inner_dst_ptr) = lookup(inner_ac.result_type.unwrap()) else { + continue; + }; if inner_dst_ptr.class.opcode != Op::TypePointer { continue; }