Skip to content

Commit

Permalink
Change computeRMSNorm args
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 19, 2023
1 parent 84f4e7f commit d8c7714
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
30 changes: 15 additions & 15 deletions src/tensor.zig
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
};
}

pub fn add(self: *const Self, other: anytype) void {
@setFloatMode(.Optimized);

std.debug.assert(self.values.len == other.values.len);

for (self.values, 0..) |*value, index| {
value.* += other.values[index];
}
}

pub fn computeMatrixVectorMultiplication(
self: *const Self,
input: anytype,
Expand All @@ -76,35 +86,25 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
return _computeScalarProduct(4, self, other);
}

pub fn add(self: *const Self, other: anytype) void {
@setFloatMode(.Optimized);

std.debug.assert(self.values.len == other.values.len);

for (self.values, 0..) |*value, index| {
value.* += other.values[index];
}
}

// Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
pub fn computeRMSNorm(self: *const Self, input: anytype, output: anytype) void {
pub fn computeRMSNorm(self: *const Self, weights: anytype, output: anytype) void {
@setFloatMode(.Optimized);

std.debug.assert(output.values.len == self.values.len);
std.debug.assert(output.values.len == input.values.len);
std.debug.assert(output.values.len == weights.values.len);

var rms_scaling_factor: f32 = 0;

for (input.values) |value| {
for (self.values) |value| {
rms_scaling_factor += value * value;
}

rms_scaling_factor /= @floatFromInt(input.values.len);
rms_scaling_factor /= @floatFromInt(self.values.len);
rms_scaling_factor += 1e-5;
rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor);

for (output.values, 0..) |*value, index| {
value.* = self.values[index] * rms_scaling_factor * input.values[index];
value.* = weights.values[index] * rms_scaling_factor * self.values[index];
}
}
};
Expand Down
10 changes: 5 additions & 5 deletions src/transformer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,23 @@ pub fn forward(self: *const Self, token: usize, position: usize) void {
@memcpy(self.hidden_buffer.values, weights.token_embedding_vectors.slice(token).values);

for (0..self.checkpoint.n_layers) |layer| {
weights.attention_norm_vectors.slice(layer).computeRMSNorm(
self.hidden_buffer,
self.hidden_buffer.computeRMSNorm(
weights.attention_norm_vectors.slice(layer),
self.attention.input_buffer,
);

self.attention.forward(layer, position);
self.hidden_buffer.add(self.attention.output_buffer);

weights.ffn_norm_vectors.slice(layer).computeRMSNorm(
self.hidden_buffer,
self.hidden_buffer.computeRMSNorm(
weights.ffn_norm_vectors.slice(layer),
self.ffn.input_buffer,
);

self.ffn.forward(layer);
self.hidden_buffer.add(self.ffn.output_buffer);
}

weights.output_norm_vector.computeRMSNorm(self.hidden_buffer, self.hidden_buffer);
self.hidden_buffer.computeRMSNorm(weights.output_norm_vector, self.hidden_buffer);
weights.output_matrix.computeMatrixVectorMultiplication(self.hidden_buffer, self.output_buffer);
}

0 comments on commit d8c7714

Please sign in to comment.