diff --git a/src/ir/module/mod.rs b/src/ir/module/mod.rs index cacd6dd..5f1665f 100644 --- a/src/ir/module/mod.rs +++ b/src/ir/module/mod.rs @@ -19,10 +19,7 @@ use crate::ir::types::{ BlockType, Body, CustomSections, DataSegment, DataSegmentKind, ElementItems, ElementKind, InstrumentationFlag, }; -use crate::ir::wrappers::{ - indirect_namemap_parser2encoder, namemap_parser2encoder, refers_to_func, refers_to_global, - update_fn_instr, update_global_instr, -}; +use crate::ir::wrappers::{indirect_namemap_parser2encoder, namemap_parser2encoder, refers_to_func, refers_to_global, refers_to_memory, update_fn_instr, update_global_instr, update_memory_instr}; use crate::opcode::{Inject, Instrumenter}; use crate::{Location, Opcode}; use log::{error, warn}; @@ -1086,6 +1083,14 @@ impl<'a> Module<'a> { } else { Self::get_mapping_generic(self.globals.iter()) }; + let memory_mapping = if self.memories.recalculate_ids { + Self::recalculate_ids( + self.imports.num_memories - self.imports.num_memories_added, + &mut self.memories, + ) + } else { + Self::get_mapping_generic(self.memories.iter()) + }; let mut module = wasm_encoder::Module::new(); let mut reencode = RoundtripReencoder; @@ -1238,7 +1243,6 @@ impl<'a> Module<'a> { if !export.deleted { match export.kind { ExternalKind::Func => { - // println!("Export updation {:?}", export.index); // Update the function indices exports.export( &export.name, @@ -1246,6 +1250,14 @@ impl<'a> Module<'a> { *func_mapping.get(&(export.index)).unwrap(), ); } + ExternalKind::Memory => { + // Update the memory indices + exports.export( + &export.name, + wasm_encoder::ExportKind::from(export.kind), + *memory_mapping.get(&(export.index)).unwrap(), + ); + } _ => { exports.export( &export.name, @@ -1373,6 +1385,9 @@ impl<'a> Module<'a> { if refers_to_global(op) { update_global_instr(op, &global_mapping); } + if refers_to_memory(op) { + update_memory_instr(op, &memory_mapping); + } if !instrument.has_instr() { encode(&op.clone(), &mut function, &mut reencode); } else { @@ -1409,6 +1424,7 @@ impl<'a> Module<'a> { before, &func_mapping, &global_mapping, + &memory_mapping, &mut function, &mut reencode, ); @@ -1420,6 +1436,7 @@ impl<'a> Module<'a> { alt, &func_mapping, &global_mapping, + &memory_mapping, &mut function, &mut reencode, ); @@ -1434,6 +1451,7 @@ impl<'a> Module<'a> { after, &func_mapping, &global_mapping, + &memory_mapping, &mut function, &mut reencode, ); @@ -1444,6 +1462,7 @@ impl<'a> Module<'a> { instrs: &mut Vec, func_mapping: &HashMap, global_mapping: &HashMap, + memory_mapping: &HashMap, function: &mut wasm_encoder::Function, reencode: &mut RoundtripReencoder, ) { @@ -1454,6 +1473,9 @@ impl<'a> Module<'a> { if refers_to_global(instr) { update_global_instr(instr, global_mapping); } + if refers_to_memory(instr) { + update_memory_instr(instr, memory_mapping); + } encode(instr, function, reencode); } } @@ -1479,18 +1501,26 @@ impl<'a> Module<'a> { if !self.data.is_empty() { let mut data = wasm_encoder::DataSection::new(); - for segment in self.data.iter() { + for segment in self.data.iter_mut() { let segment_data = segment.data.iter().copied(); - match (*segment).clone().kind { + match &mut segment.kind { DataSegmentKind::Passive => data.passive(segment_data), DataSegmentKind::Active { memory_index, offset_expr, - } => data.active( - memory_index, - &offset_expr.to_wasmencoder_type(), - segment_data, - ), + } => { + let new_idx = match memory_mapping.get(memory_index) { + Some(new_index) => { + *new_index + } + None => panic!("Attempting to reference a deleted memory, ID: {}", memory_index), + }; + data.active( + new_idx, + &offset_expr.to_wasmencoder_type(), + segment_data, + ) + }, }; } module.section(&data); @@ -1585,13 +1615,25 @@ impl<'a> Module<'a> { // ==== Memory Management ==== // =========================== + pub fn add_local_memory( + &mut self, + ty: MemoryType + ) -> MemoryID { + let local_mem = LocalMemory { + mem_id: MemoryID(0), // will be fixed + }; + + self.num_local_memories += 1; + self.memories.add_local_mem(local_mem, ty) + } + pub fn add_import_memory( &mut self, module: String, name: String, ty: MemoryType ) -> (MemoryID, ImportsID) { - let (mem_id, imp_id) = self.add_import(Import { + let (imp_mem_id, imp_id) = self.add_import(Import { module: module.leak(), name: name.clone().leak(), ty: TypeRef::Memory(ty), @@ -1599,7 +1641,20 @@ impl<'a> Module<'a> { deleted: false, }); - (MemoryID(mem_id), imp_id) + // Add to memories as well as it has imported memories + self.memories + .add_import_mem(imp_id, ty, imp_mem_id); + (MemoryID(imp_mem_id), imp_id) + } + + /// Delete a memory from the module. + pub fn delete_memory(&mut self, mem_id: MemoryID) { + self.memories.delete(mem_id); + if let MemKind::Import(ImportedMemory { import_id, .. }) = + self.memories.get_kind(mem_id) + { + self.imports.delete(*import_id); + } } // ============================= diff --git a/src/ir/module/module_memories.rs b/src/ir/module/module_memories.rs index 91a8053..7370c63 100644 --- a/src/ir/module/module_memories.rs +++ b/src/ir/module/module_memories.rs @@ -190,12 +190,12 @@ impl Memory { } } - /// Change the kind of the memory - pub(crate) fn set_kind(&mut self, kind: MemKind) { - self.kind = kind; - // Resets deletion - self.deleted = false; - } + // /// Change the kind of the memory + // pub(crate) fn set_kind(&mut self, kind: MemKind) { + // self.kind = kind; + // // Resets deletion + // self.deleted = false; + // } /// Get the kind of the memory pub fn kind(&self) -> &MemKind { diff --git a/src/ir/wrappers.rs b/src/ir/wrappers.rs index e97f656..29e427f 100644 --- a/src/ir/wrappers.rs +++ b/src/ir/wrappers.rs @@ -364,6 +364,83 @@ pub(crate) fn refers_to_global(op: &Operator) -> bool { ) } +pub(crate) fn refers_to_memory(op: &Operator) -> bool { + matches!( + op, + Operator::I32Load { .. } | + Operator::I32Load8S { .. } | + Operator::I32Load8U { .. } | + Operator::I32Load16S { .. } | + Operator::I32Load16U { .. } | + Operator::I64Load { .. } | + Operator::I64Load8U { .. } | + Operator::I64Load8S { .. } | + Operator::I64Load16U { .. } | + Operator::I64Load16S { .. } | + Operator::I64Load32U { .. } | + Operator::I64Load32S { .. } | + Operator::F32Load { .. } | + Operator::F64Load { .. } | + Operator::V128Load { .. } | + Operator::I32AtomicLoad { .. } | + Operator::I32AtomicLoad8U { .. } | + Operator::I32AtomicLoad16U { .. } | + Operator::I64AtomicLoad8U { .. } | + Operator::I64AtomicLoad16U { .. } | + Operator::I64AtomicLoad32U { .. } | + Operator::V128Load8Lane { .. } | + Operator::V128Load16Lane { .. } | + Operator::V128Load32Lane { .. } | + Operator::V128Load64Lane { .. } | + Operator::V128Load8Splat { .. } | + Operator::V128Load16Splat { .. } | + Operator::V128Load32Splat { .. } | + Operator::V128Load64Splat { .. } | + Operator::V128Load8x8S { .. } | + Operator::V128Load8x8U { .. } | + Operator::V128Load16x4U { .. } | + Operator::V128Load16x4S { .. } | + Operator::V128Load32Zero { .. } | + Operator::V128Load32x2S { .. } | + Operator::V128Load32x2U { .. } | + Operator::V128Load64Zero { .. } | + + // stores + Operator::I32Store { .. } | + Operator::I32Store8 { .. } | + Operator::I32Store16 { .. } | + Operator::I64Store { .. } | + Operator::I64Store8 { .. } | + Operator::I64Store16 { .. } | + Operator::I64Store32 { .. } | + Operator::F32Store { .. } | + Operator::F64Store { .. } | + Operator::I32AtomicStore { .. } | + Operator::I32AtomicStore8 { .. } | + Operator::I32AtomicStore16 { .. } | + Operator::I64AtomicStore { .. } | + Operator::I64AtomicStore8 { .. } | + Operator::I64AtomicStore16 { .. } | + Operator::I64AtomicStore32 { .. } | + Operator::V128Store { .. } | + Operator::V128Store8Lane { .. } | + Operator::V128Store16Lane { .. } | + Operator::V128Store32Lane { .. } | + Operator::V128Store64Lane { .. } | + + // memory operations + Operator::MemoryAtomicNotify { .. } | + Operator::MemoryAtomicWait32 { .. } | + Operator::MemoryAtomicWait64 { .. } | + Operator::MemoryGrow { .. } | + Operator::MemoryFill { .. } | + Operator::MemoryInit { .. } | + Operator::MemorySize { .. } | + Operator::MemoryDiscard { .. } | + Operator::MemoryCopy { .. } + ) +} + pub(crate) fn update_fn_instr(op: &mut Operator, mapping: &HashMap) { match op { Operator::Call { function_index } | Operator::RefFunc { function_index } => { @@ -401,3 +478,108 @@ pub(crate) fn update_global_instr(op: &mut Operator, mapping: &HashMap _ => panic!("Operation doesn't need to be checked for global IDs!"), } } + +pub(crate) fn update_memory_instr(op: &mut Operator, mapping: &HashMap) { + match op { + // loads + Operator::I32Load { memarg } | + Operator::I32Load8S { memarg } | + Operator::I32Load8U { memarg } | + Operator::I32Load16S { memarg } | + Operator::I32Load16U { memarg } | + Operator::I64Load { memarg } | + Operator::I64Load8U { memarg } | + Operator::I64Load8S { memarg } | + Operator::I64Load16U { memarg } | + Operator::I64Load16S { memarg } | + Operator::I64Load32U { memarg } | + Operator::I64Load32S { memarg } | + Operator::F32Load { memarg } | + Operator::F64Load { memarg } | + Operator::V128Load { memarg } | + Operator::I32AtomicLoad { memarg } | + Operator::I32AtomicLoad8U { memarg } | + Operator::I32AtomicLoad16U { memarg } | + Operator::I64AtomicLoad8U { memarg } | + Operator::I64AtomicLoad16U { memarg } | + Operator::I64AtomicLoad32U { memarg } | + Operator::V128Load8Lane { memarg, .. } | + Operator::V128Load16Lane { memarg, .. } | + Operator::V128Load32Lane { memarg, .. } | + Operator::V128Load64Lane { memarg, .. } | + Operator::V128Load8Splat { memarg } | + Operator::V128Load16Splat { memarg } | + Operator::V128Load32Splat { memarg } | + Operator::V128Load64Splat { memarg } | + Operator::V128Load8x8S { memarg } | + Operator::V128Load8x8U { memarg } | + Operator::V128Load16x4U { memarg } | + Operator::V128Load16x4S { memarg } | + Operator::V128Load32Zero { memarg } | + Operator::V128Load32x2S { memarg } | + Operator::V128Load32x2U { memarg } | + Operator::V128Load64Zero { memarg } | + + // stores + Operator::I32Store {memarg} | + Operator::I32Store8 {memarg} | + Operator::I32Store16 {memarg} | + Operator::I64Store {memarg} | + Operator::I64Store8 {memarg} | + Operator::I64Store16 {memarg} | + Operator::I64Store32 {memarg} | + Operator::F32Store {memarg} | + Operator::F64Store {memarg} | + Operator::I32AtomicStore {memarg} | + Operator::I32AtomicStore8 {memarg} | + Operator::I32AtomicStore16 {memarg} | + Operator::I64AtomicStore {memarg} | + Operator::I64AtomicStore8 {memarg} | + Operator::I64AtomicStore16 {memarg} | + Operator::I64AtomicStore32 {memarg} | + Operator::V128Store {memarg} | + Operator::V128Store8Lane {memarg, ..} | + Operator::V128Store16Lane {memarg, ..} | + Operator::V128Store32Lane {memarg, ..} | + Operator::V128Store64Lane {memarg, ..} | + + // memory operations + Operator::MemoryAtomicNotify {memarg} | + Operator::MemoryAtomicWait32 {memarg} | + Operator::MemoryAtomicWait64 {memarg} => { + match mapping.get(&(memarg.memory)) { + Some(new_index) => { + memarg.memory = *new_index; + } + None => panic!("Attempting to reference a deleted memory, ID: {}", memarg.memory), + } + } + Operator::MemoryGrow {mem} | + Operator::MemoryFill {mem} | + Operator::MemoryInit {mem, ..} | + Operator::MemorySize {mem} | + Operator::MemoryDiscard {mem} => { + match mapping.get(mem) { + Some(new_index) => { + *mem = *new_index; + } + None => panic!("Attempting to reference a deleted memory, ID: {}", mem), + } + } + Operator::MemoryCopy {src_mem, dst_mem} => { + match mapping.get(src_mem) { + Some(new_index) => { + *src_mem = *new_index; + } + None => panic!("Attempting to reference a deleted memory, ID: {}", src_mem), + } + match mapping.get(dst_mem) { + Some(new_index) => { + *dst_mem = *new_index; + } + None => panic!("Attempting to reference a deleted memory, ID: {}", dst_mem), + } + } + _ => panic!("Operation doesn't need to be checked for memory IDs!"), + } +}