Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2423,7 +2423,32 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
"pointercast called on non-pointer dest type: {other:?}"
)),
};
let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);

let dst_pointee_ty = self.lookup_type(dest_pointee);
let dest_pointee_size = dst_pointee_ty.sizeof(self);
let src_pointee_ty = self.lookup_type(ptr_pointee);

// *[T; N] -> *RuntimeArray<T>
if let SpirvType::Array {
element: elem_ty, ..
} = src_pointee_ty
&& let SpirvType::RuntimeArray {
element: rt_elem_ty,
} = dst_pointee_ty
&& elem_ty == rt_elem_ty
{
let zero = self.constant_u32(self.span(), 0).def(self);
let elem_ptr_ty = self.type_ptr_to(elem_ty);
let elem_ptr = self
.emit()
.in_bounds_access_chain(elem_ptr_ty, None, ptr.def(self), [zero])
.unwrap();
return self
.emit()
.bitcast(dest_ty, None, elem_ptr)
.unwrap()
.with_type(dest_ty);
}

if let Some((indices, _)) = self.recover_access_chain_from_offset(
ptr_pointee,
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ pub fn link(
peephole_opts::composite_construct(&types, func);
peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
peephole_opts::fold_array_bitcast_access_chain(&types, func);
}
}

Expand Down
102 changes: 102 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/peephole_opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,105 @@ pub fn fold_load_from_constant_variable(module: &mut Module) {
}
}
}

/// Eliminate the `OpBitcast` that arises from `*[T; N] → *RuntimeArray<T>` pointer casts.
///
/// When a local array is coerced to a slice (`&[T;N] as &[T]`), codegen emits:
/// ```text
/// %elem0 = OpInBoundsAccessChain %arr 0 ; pointer to arr[0]
/// %rta = OpBitcast %elem0 ; *RuntimeArray<T> (invalid in Logical SPIR-V)
/// ```
/// After inlining the slice-taking function, indexing becomes:
/// ```text
/// %ei = OpInBoundsAccessChain %rta i ; data[i]
///
pub fn fold_array_bitcast_access_chain(
types: &FxHashMap<Word, Instruction>,
function: &mut Function,
) {
let func_defs: FxHashMap<Word, Instruction> = function
.all_inst_iter()
.filter_map(|inst| Some((inst.result_id?, inst.clone())))
.collect();

// look up an ID in either function-local defs or module-level types/globals.
let lookup =
|id: Word| -> Option<&Instruction> { func_defs.get(&id).or_else(|| types.get(&id)) };

for block in &mut function.blocks {
for inst in &mut block.instructions {
if !matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain) {
continue;
}
if inst.operands.is_empty() {
continue;
}
let base_id = inst.operands[0].unwrap_id_ref();

// base must be an OpBitcast
let Some(bitcast) = lookup(base_id) else {
continue;
};
if bitcast.class.opcode != Op::Bitcast {
continue;
}

// bitcast result type must be *SC RuntimeArray<T>
let Some(bitcast_dst_ptr) = lookup(bitcast.result_type.unwrap()) else {
continue;
};
if bitcast_dst_ptr.class.opcode != Op::TypePointer {
continue;
}
let rta_type_id = bitcast_dst_ptr.operands[1].unwrap_id_ref();
let Some(rta_ty) = lookup(rta_type_id) else {
continue;
};
if rta_ty.class.opcode != Op::TypeRuntimeArray {
continue;
}
let rta_elem_ty = rta_ty.operands[0].unwrap_id_ref();

// bitcast source must be OpInBoundsAccessChain(arr, 0)
let bitcast_src_id = bitcast.operands[0].unwrap_id_ref();
let Some(inner_ac) = lookup(bitcast_src_id) else {
continue;
};
if inner_ac.class.opcode != Op::InBoundsAccessChain {
continue;
}
// Exactly one index operand
if inner_ac.operands.len() != 2 {
continue;
}
// That index must be the constant 0
let idx0_id = inner_ac.operands[1].unwrap_id_ref();
let Some(idx0) = lookup(idx0_id) else {
continue;
};
if idx0.class.opcode != Op::Constant {
continue;
}
if !matches!(idx0.operands[0], Operand::LiteralBit32(0)) {
continue;
}

// inner AccessChain result type must be *SC T where T == rta_elem_ty
let Some(inner_dst_ptr) = lookup(inner_ac.result_type.unwrap()) else {
continue;
};
if inner_dst_ptr.class.opcode != Op::TypePointer {
continue;
}
let elem_ty = inner_dst_ptr.operands[1].unwrap_id_ref();
if elem_ty != rta_elem_ty {
continue;
}

// AccessChain(Bitcast(InBoundsAccessChain(arr, 0)), i)
// Replace base with arr — the dead bitcast and intermediate AC are cleaned by DCE.
let arr_id = inner_ac.operands[0].unwrap_id_ref();
inst.operands[0] = Operand::IdRef(arr_id);
}
}
}
33 changes: 33 additions & 0 deletions tests/compiletests/ui/lang/core/array-slice-cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// build-pass
// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "%\d+ = OpString .*\n" -> ""
// normalize-stderr-test "^(; .*\n)*" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-spv1.4
// ignore-spv1.5
// ignore-spv1.6
// ignore-vulkan1.0
// ignore-vulkan1.1
use spirv_std::spirv;

fn do_work(data: &[u32], slab: &mut [u32]) {
slab[0] = data[0];
slab[1] = data[1];
slab[2] = data[2];
}

#[spirv(compute(threads(64)))]
pub fn compute_shader(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slab: &mut [u32],
#[spirv(global_invocation_id)] global_id: glam::UVec3,
) {
let data = [global_id.x, global_id.y, global_id.z];
do_work(&data, slab);
}
116 changes: 116 additions & 0 deletions tests/compiletests/ui/lang/core/array-slice-cast.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
OpCapability Shader
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "compute_shader" %2 %3
OpExecutionMode %1 LocalSize 64 1 1
OpName %2 "slab"
OpName %3 "global_id"
OpDecorate %5 ArrayStride 4
OpDecorate %6 Block
OpMemberDecorate %6 0 Offset 0
OpDecorate %2 Binding 0
OpDecorate %2 DescriptorSet 0
OpDecorate %3 BuiltIn GlobalInvocationId
%7 = OpTypeInt 32 0
%5 = OpTypeRuntimeArray %7
%6 = OpTypeStruct %5
%8 = OpTypePointer StorageBuffer %6
%9 = OpTypeVector %7 3
%10 = OpTypePointer Input %9
%11 = OpTypeVoid
%12 = OpTypeFunction %11
%13 = OpConstant %7 3
%14 = OpTypeArray %7 %13
%15 = OpTypePointer Function %14
%16 = OpTypePointer StorageBuffer %5
%2 = OpVariable %8 StorageBuffer
%17 = OpConstant %7 0
%3 = OpVariable %10 Input
%18 = OpTypePointer Function %7
%19 = OpConstant %7 1
%20 = OpConstant %7 2
%21 = OpTypeBool
%22 = OpTypePointer StorageBuffer %7
%1 = OpFunction %11 None %12
%23 = OpLabel
%24 = OpVariable %15 Function
%25 = OpInBoundsAccessChain %16 %2 %17
%26 = OpArrayLength %7 %2 0
%27 = OpLoad %9 %3
%28 = OpCompositeExtract %7 %27 0
%29 = OpCompositeExtract %7 %27 1
%30 = OpCompositeExtract %7 %27 2
%31 = OpInBoundsAccessChain %18 %24 %17
OpStore %31 %28
%32 = OpInBoundsAccessChain %18 %24 %19
OpStore %32 %29
%33 = OpInBoundsAccessChain %18 %24 %20
OpStore %33 %30
%34 = OpULessThan %21 %17 %13
OpNoLine
OpSelectionMerge %35 None
OpBranchConditional %34 %36 %37
%36 = OpLabel
OpBranch %35
%37 = OpLabel
OpReturn
%35 = OpLabel
%38 = OpInBoundsAccessChain %18 %24 %17
%39 = OpLoad %7 %38
%40 = OpULessThan %21 %17 %26
OpNoLine
OpSelectionMerge %41 None
OpBranchConditional %40 %42 %43
%42 = OpLabel
OpBranch %41
%43 = OpLabel
OpReturn
%41 = OpLabel
%44 = OpInBoundsAccessChain %22 %25 %17
OpStore %44 %39
%45 = OpULessThan %21 %19 %13
OpNoLine
OpSelectionMerge %46 None
OpBranchConditional %45 %47 %48
%47 = OpLabel
OpBranch %46
%48 = OpLabel
OpReturn
%46 = OpLabel
%49 = OpInBoundsAccessChain %18 %24 %19
%50 = OpLoad %7 %49
%51 = OpULessThan %21 %19 %26
OpNoLine
OpSelectionMerge %52 None
OpBranchConditional %51 %53 %54
%53 = OpLabel
OpBranch %52
%54 = OpLabel
OpReturn
%52 = OpLabel
%55 = OpInBoundsAccessChain %22 %25 %19
OpStore %55 %50
%56 = OpULessThan %21 %20 %13
OpNoLine
OpSelectionMerge %57 None
OpBranchConditional %56 %58 %59
%58 = OpLabel
OpBranch %57
%59 = OpLabel
OpReturn
%57 = OpLabel
%60 = OpInBoundsAccessChain %18 %24 %20
%61 = OpLoad %7 %60
%62 = OpULessThan %21 %20 %26
OpNoLine
OpSelectionMerge %63 None
OpBranchConditional %62 %64 %65
%64 = OpLabel
OpBranch %63
%65 = OpLabel
OpReturn
%63 = OpLabel
%66 = OpInBoundsAccessChain %22 %25 %20
OpStore %66 %61
OpNoLine
OpReturn
OpFunctionEnd
Loading