From d8c771401cdee7865013e08d052025b13cf45d5e Mon Sep 17 00:00:00 2001
From: Clemens Akens
Date: Thu, 19 Oct 2023 15:19:34 +0200
Subject: [PATCH] Change computeRMSNorm args

---
 src/tensor.zig      | 30 +++++++++++++++---------------
 src/transformer.zig | 10 +++++-----
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/tensor.zig b/src/tensor.zig
index 8eee001..820886c 100644
--- a/src/tensor.zig
+++ b/src/tensor.zig
@@ -50,6 +50,16 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
             };
         }
 
+        pub fn add(self: *const Self, other: anytype) void {
+            @setFloatMode(.Optimized);
+
+            std.debug.assert(self.values.len == other.values.len);
+
+            for (self.values, 0..) |*value, index| {
+                value.* += other.values[index];
+            }
+        }
+
         pub fn computeMatrixVectorMultiplication(
             self: *const Self,
             input: anytype,
@@ -76,35 +86,25 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
             return _computeScalarProduct(4, self, other);
         }
 
-        pub fn add(self: *const Self, other: anytype) void {
-            @setFloatMode(.Optimized);
-
-            std.debug.assert(self.values.len == other.values.len);
-
-            for (self.values, 0..) |*value, index| {
-                value.* += other.values[index];
-            }
-        }
-
         // Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
-        pub fn computeRMSNorm(self: *const Self, input: anytype, output: anytype) void {
+        pub fn computeRMSNorm(self: *const Self, weights: anytype, output: anytype) void {
            @setFloatMode(.Optimized);
 
            std.debug.assert(output.values.len == self.values.len);
-            std.debug.assert(output.values.len == input.values.len);
+            std.debug.assert(output.values.len == weights.values.len);
 
            var rms_scaling_factor: f32 = 0;
 
-            for (input.values) |value| {
+            for (self.values) |value| {
                rms_scaling_factor += value * value;
            }
 
-            rms_scaling_factor /= @floatFromInt(input.values.len);
+            rms_scaling_factor /= @floatFromInt(self.values.len);
            rms_scaling_factor += 1e-5;
            rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor);
 
            for (output.values, 0..) |*value, index| {
-                value.* = self.values[index] * rms_scaling_factor * input.values[index];
+                value.* = weights.values[index] * rms_scaling_factor * self.values[index];
            }
        }
    };
diff --git a/src/transformer.zig b/src/transformer.zig
index 78dfbca..ca506f5 100644
--- a/src/transformer.zig
+++ b/src/transformer.zig
@@ -61,16 +61,16 @@ pub fn forward(self: *const Self, token: usize, position: usize) void {
     @memcpy(self.hidden_buffer.values, weights.token_embedding_vectors.slice(token).values);
 
     for (0..self.checkpoint.n_layers) |layer| {
-        weights.attention_norm_vectors.slice(layer).computeRMSNorm(
-            self.hidden_buffer,
+        self.hidden_buffer.computeRMSNorm(
+            weights.attention_norm_vectors.slice(layer),
             self.attention.input_buffer,
         );
 
         self.attention.forward(layer, position);
 
         self.hidden_buffer.add(self.attention.output_buffer);
 
-        weights.ffn_norm_vectors.slice(layer).computeRMSNorm(
-            self.hidden_buffer,
+        self.hidden_buffer.computeRMSNorm(
+            weights.ffn_norm_vectors.slice(layer),
             self.ffn.input_buffer,
         );
@@ -78,6 +78,6 @@ pub fn forward(self: *const Self, token: usize, position: usize) void {
         self.hidden_buffer.add(self.ffn.output_buffer);
     }
 
-    weights.output_norm_vector.computeRMSNorm(self.hidden_buffer, self.hidden_buffer);
+    self.hidden_buffer.computeRMSNorm(weights.output_norm_vector, self.hidden_buffer);
 
     weights.output_matrix.computeMatrixVectorMultiplication(self.hidden_buffer, self.output_buffer);
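
Note: below is a minimal, self-contained sketch of the calling convention this patch introduces, in case it helps review. The `Vector` type and `main` function are hypothetical stand-ins, not the project's actual `Tensor`; only the RMSNorm math and the new argument order (activations as the receiver, norm weights as the first argument) mirror the patch.

const std = @import("std");

const Vector = struct {
    values: []f32,

    // Same math as computeRMSNorm in the patch: scale `self` by
    // 1 / sqrt(mean(self^2) + eps), then multiply element-wise by `weights`.
    pub fn computeRMSNorm(self: *const Vector, weights: Vector, output: Vector) void {
        std.debug.assert(output.values.len == self.values.len);
        std.debug.assert(output.values.len == weights.values.len);

        var rms_scaling_factor: f32 = 0;

        for (self.values) |value| {
            rms_scaling_factor += value * value;
        }

        rms_scaling_factor /= @floatFromInt(self.values.len);
        rms_scaling_factor += 1e-5;
        rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor);

        for (output.values, 0..) |*value, index| {
            value.* = weights.values[index] * rms_scaling_factor * self.values[index];
        }
    }
};

pub fn main() void {
    var hidden_values = [_]f32{ 1.0, 2.0, 3.0, 4.0 };
    var weight_values = [_]f32{ 1.0, 1.0, 1.0, 1.0 };
    var output_values = [_]f32{ 0.0, 0.0, 0.0, 0.0 };

    const hidden = Vector{ .values = &hidden_values };
    const weights = Vector{ .values = &weight_values };
    const output = Vector{ .values = &output_values };

    // New argument order: activations.computeRMSNorm(weights, output)
    // (previously: weights.computeRMSNorm(activations, output)).
    hidden.computeRMSNorm(weights, output);

    std.debug.print("{any}\n", .{output.values});
}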