-
Notifications
You must be signed in to change notification settings - Fork 11
/
link_input_outputs.rs
347 lines (294 loc) · 12 KB
/
link_input_outputs.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
use librashader_common::map::FastHashMap;
use rspirv::dr::{Builder, Module, Operand};
use spirv::{Decoration, Op, StorageClass};
/// Performs dead-code elimination on the inputs of a fragment shader, then links the
/// stages by downgrading the vertex shader outputs that fed removed fragment inputs to
/// `Private` global variables.
pub struct LinkInputs<'a> {
    /// Builder holding the fragment shader module; unused inputs are trimmed from it.
    pub frag_builder: &'a mut Builder,
    /// Builder holding the vertex shader module; outputs whose matching fragment input
    /// was removed are downgraded in it.
    pub vert_builder: &'a mut Builder,
    /// Vertex shader outputs: `Location` decoration value -> output variable id.
    pub outputs: FastHashMap<u32, spirv::Word>,
    /// Fragment shader inputs scheduled for removal: input variable id -> `Location`
    /// decoration value.
    pub inputs_to_remove: FastHashMap<spirv::Word, u32>,
}
impl<'a> LinkInputs<'a> {
/// Get the value of the location of the inout in the module
fn find_location(module: &Module, id: spirv::Word) -> Option<u32> {
module.annotations.iter().find_map(|op| {
if op.class.opcode != Op::Decorate {
return None;
}
let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else {
return None;
};
let Some(&Operand::IdRef(target)) = op.operands.get(0) else {
return None;
};
if target != id {
return None;
}
let Some(&Operand::LiteralBit32(binding)) = op.operands.get(2) else {
return None;
};
return Some(binding);
})
}
pub fn new(vert: &'a mut Builder, frag: &'a mut Builder, keep_if_bound: bool) -> Self {
let mut outputs = FastHashMap::default();
let mut inputs_to_remove = FastHashMap::default();
let mut inputs = FastHashMap::default();
for global in frag.module_ref().types_global_values.iter() {
if global.class.opcode == spirv::Op::Variable
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
{
if let Some(id) = global.result_id {
let Some(location) = Self::find_location(frag.module_ref(), id) else {
continue;
};
inputs_to_remove.insert(id, location);
inputs.insert(location, id);
}
}
}
for global in vert.module_ref().types_global_values.iter() {
if global.class.opcode == spirv::Op::Variable
&& global.operands[0] == Operand::StorageClass(StorageClass::Output)
{
if let Some(id) = global.result_id {
let Some(location) = Self::find_location(vert.module_ref(), id) else {
continue;
};
// Add to list of outputs
outputs.insert(location, id);
// Keep the input, if it is bound to both stages.Otherwise, do DCE analysis on
// the input, and remove it regardless if bound, if unused by the fragment stage.
if keep_if_bound {
if let Some(frag_ref) = inputs.get(&location) {
// if something is bound to the same location in the vertex shader,
// we're good.
inputs_to_remove.remove(&frag_ref);
}
}
}
}
}
Self {
frag_builder: frag,
vert_builder: vert,
outputs,
inputs_to_remove,
}
}
pub fn do_pass(&mut self) {
self.trim_inputs();
self.downgrade_outputs();
self.put_vertex_variables_to_end();
}
fn put_vertex_variables_to_end(&mut self) {
// this is easier than doing proper topo sort.
// we need it so that all type definitions are valid before
// being referred to by a variable.
let mut vars = Vec::new();
self.vert_builder
.module_mut()
.types_global_values
.retain(|instr| {
if instr.class.opcode == spirv::Op::Variable {
vars.push(instr.clone());
return false;
};
true
});
self.vert_builder
.module_mut()
.types_global_values
.append(&mut vars);
}
/// Downgrade dead inputs corresponding to outputs to global variables, keeping existing mappings.
fn downgrade_outputs(&mut self) {
let dead_outputs = self
.inputs_to_remove
.values()
.filter_map(|i| self.outputs.get(i).cloned().map(|w| (w, ())))
.collect::<FastHashMap<spirv::Word, ()>>();
let mut pointer_types_to_downgrade = FastHashMap::default();
// Map from Pointer type to pointee
let mut pointer_type_pointee = FastHashMap::default();
// Map from StorageClass Output to StorageClass Private
let mut downgraded_pointer_types = FastHashMap::default();
// First collect all the pointer types that are needed for dead outputs.
for global in self.vert_builder.module_ref().types_global_values.iter() {
if global.class.opcode != spirv::Op::Variable
|| global.operands[0] != Operand::StorageClass(StorageClass::Output)
{
continue;
}
if let Some(id) = global.result_id {
if !dead_outputs.contains_key(&id) {
continue;
}
if let Some(result_type) = global.result_type {
pointer_types_to_downgrade.insert(result_type, ());
}
}
}
// Collect all the pointee types of pointer types to downgrade
for global in self.vert_builder.module_ref().types_global_values.iter() {
if global.class.opcode != spirv::Op::TypePointer
|| global.operands[0] != Operand::StorageClass(StorageClass::Output)
{
continue;
}
if let Some(id) = global.result_id {
if !pointer_types_to_downgrade.contains_key(&id) {
continue;
}
let Some(pointee_type) = global.operands[1].id_ref_any() else {
continue;
};
pointer_type_pointee.insert(id, pointee_type);
}
}
// Create pointer types for everything we saw above with Private storage class.
// We don't have to deal with OpTypeForwardPointer, because PhysicalStorageBuffer
// is not valid in slang shaders, and we're only working with Vulkan inputs.
for (pointer_type, pointee_type) in pointer_type_pointee.iter() {
// Create a new private type
let private_pointer_type =
self.vert_builder
.type_pointer(None, StorageClass::Private, *pointee_type);
// Add it to the mapping
downgraded_pointer_types.insert(pointer_type, private_pointer_type);
}
// Downgrade the OpVariable storage class and reassign the types.
for global in self
.vert_builder
.module_mut()
.types_global_values
.iter_mut()
{
if global.class.opcode != spirv::Op::Variable
|| global.operands[0] != Operand::StorageClass(StorageClass::Output)
{
continue;
}
if let Some(id) = global.result_id {
if !dead_outputs.contains_key(&id) {
continue;
}
// downgrade the OpVariable if it's in dead_outputs
global.operands[0] = Operand::StorageClass(StorageClass::Private);
// Get the result type. If there's no result type it's invalid anyways
// so it doesn't matter that we downgraded early (better downgraded than unmatched)
let Some(result_type) = &mut global.result_type else {
continue;
};
let Some(new_type) = downgraded_pointer_types.get(&*result_type) else {
// We should have created one above.
continue;
};
// Set the type of the OpVariable to the same type with Private storageclass.
*result_type = *new_type;
}
}
// Strip decorations of downgraded variables.
self.vert_builder.module_mut().annotations.retain_mut(|op| {
if op.class.opcode != Op::Decorate {
return true;
}
let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else {
return true;
};
let Some(&Operand::IdRef(target)) = op.operands.get(0) else {
return true;
};
// If target is in dead outputs, then don't keep it.
!dead_outputs.contains_key(&target)
});
for entry_point in self.vert_builder.module_mut().entry_points.iter_mut() {
let mut index = 0;
entry_point.operands.retain(|s| {
// Skip the execution mode, entry point reference, and name.
if index < 3 {
index += 1;
return true;
}
index += 1;
// Ignore any non-IdRef
let Operand::IdRef(id_ref) = s else {
return true;
};
// If the entry point contains a dead outputs, remove it from the interface.
!dead_outputs.contains_key(id_ref)
});
}
}
// Trim unused fragment shader inputs
fn trim_inputs(&mut self) {
let functions = &self.frag_builder.module_ref().functions;
// literally if it has any reference at all we can keep it
for function in functions {
for param in &function.parameters {
for op in ¶m.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs_to_remove.contains_key(&word) {
self.inputs_to_remove.remove(&word);
}
}
}
}
for block in &function.blocks {
for inst in &block.instructions {
for op in &inst.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs_to_remove.contains_key(&word) {
self.inputs_to_remove.remove(&word);
}
}
}
}
}
}
// ok well guess we dont
self.frag_builder.module_mut().debug_names.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs_to_remove.contains_key(&word) {
return false;
}
}
}
return true;
});
self.frag_builder.module_mut().annotations.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs_to_remove.contains_key(&word) {
return false;
}
}
}
return true;
});
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
entry_point.operands.retain(|op| {
if let Some(word) = op.id_ref_any() {
if self.inputs_to_remove.contains_key(&word) {
return false;
}
}
return true;
})
}
self.frag_builder
.module_mut()
.types_global_values
.retain(|instr| {
let Some(id) = instr.result_id else {
return true;
};
!self.inputs_to_remove.contains_key(&id)
});
}
}