Adding flash attention to training for minrf
Fixed a bug where we don't wait for the stream when printing.
liuliu committed Jun 11, 2024
1 parent 7a222c9 commit 2c1e63b
Showing 3 changed files with 34 additions and 19 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
    name = "ccv",
-    commit = "b6ab47cf662cfc4e5c0236e4184f0274a3df5fe6",
+    commit = "a33f509e971477ce0d23aca2c2a165819fb2717f",
    remote = "https://github.com/liuliu/ccv.git",
-    shallow_since = "1716395188 -0400",
+    shallow_since = "1718125376 -0400",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
47 changes: 30 additions & 17 deletions examples/minrf/main.swift
@@ -40,7 +40,9 @@ public func LabelEmbedder<T: TensorNumeric>(_ dataType: T.Type, numClasses: Int,
  return labelEmbed
}

-func SelfAttention(prefix: String, k: Int, h: Int, hk: Int, b: Int, t: Int) -> Model {
+func SelfAttention(
+  prefix: String, k: Int, h: Int, hk: Int, b: Int, t: Int, usesFlashAttention: Bool
+) -> Model {
  let x = Input()
  let rot = Input()
  let tokeys = Dense(count: k * hk, noBias: true, name: "k_proj")
@@ -50,17 +52,26 @@ func SelfAttention(prefix: String, k: Int, h: Int, hk: Int, b: Int, t: Int) -> M
  var keys = k_norm(tokeys(x)).reshaped([b, t, hk, k])
  let q_norm = LayerNorm(epsilon: 1e-6, axis: [2], name: "q_norm")
  var queries = q_norm(toqueries(x)).reshaped([b, t, h, k])
-  let values = tovalues(x).reshaped([b, t, hk, k]).transposed(1, 2)
-  keys = Functional.cmul(left: keys, right: rot)
-  keys = keys.transposed(1, 2)
-  queries = Functional.cmul(left: queries, right: rot)
-  queries = ((1.0 / Float(k).squareRoot()) * queries).transposed(1, 2)
-  var dot = Matmul(transposeB: (2, 3))(queries, keys)
-  dot = dot.reshaped([b * h * t, t])
-  dot = dot.softmax()
-  dot = dot.reshaped([b, h, t, t])
-  var out = dot * values
-  out = out.reshaped([b, h, t, k]).transposed(1, 2).reshaped([b, t, h * k])
+  var out: Model.IO
+  if usesFlashAttention {
+    let values = tovalues(x).reshaped(.NHWC(b, t, hk, k)).to(.Float16)
+    keys = Functional.cmul(left: keys, right: rot).reshaped(.NHWC(b, t, hk, k)).to(.Float16)
+    queries = Functional.cmul(left: queries, right: rot).reshaped(.NHWC(b, t, h, k)).to(.Float16)
+    out = ScaledDotProductAttention(scale: 1.0 / Float(k).squareRoot())(queries, keys, values)
+      .reshaped(.CHW(b, t, h * k)).to(of: x)
+  } else {
+    let values = tovalues(x).reshaped([b, t, hk, k]).transposed(1, 2)
+    keys = Functional.cmul(left: keys, right: rot)
+    keys = keys.transposed(1, 2)
+    queries = Functional.cmul(left: queries, right: rot)
+    queries = ((1.0 / Float(k).squareRoot()) * queries).transposed(1, 2)
+    var dot = Matmul(transposeB: (2, 3))(queries, keys)
+    dot = dot.reshaped([b * h * t, t])
+    dot = dot.softmax()
+    dot = dot.reshaped([b, h, t, t])
+    out = dot * values
+    out = out.reshaped([b, h, t, k]).transposed(1, 2).reshaped([b, t, h * k])
+  }
  let unifyheads = Dense(count: k * h, noBias: true, name: "out_proj")
  out = unifyheads(out)
  return Model([x, rot], [out])
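
The two branches above build the same module interface; only the execution path changes. A minimal usage sketch (not part of the commit), with illustrative sizes taken from the DiT configuration further down (hiddenSize 256, 8 heads, 16 x 16 = 256 tokens):

// Sketch: same inputs and outputs, different attention kernels.
// Sizes are illustrative: k = 256 / 8 = 32, h = hk = 8, b = 1, t = 256.
let fusedAttention = SelfAttention(
  prefix: "", k: 32, h: 8, hk: 8, b: 1, t: 256, usesFlashAttention: true)
let referenceAttention = SelfAttention(
  prefix: "", k: 32, h: 8, hk: 8, b: 1, t: 256, usesFlashAttention: false)
// Both map [x, rot] to a [b, t, h * k] output; the fused path reshapes to NHWC,
// casts to Float16 for ScaledDotProductAttention, and casts back with .to(of: x).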
@@ -76,13 +87,14 @@ func FeedForward(hiddenSize: Int, intermediateSize: Int, name: String = "") -> M
  return Model([x], [out], name: name)
}

-func TransformerBlock(k: Int, h: Int, hk: Int, b: Int, t: Int) -> Model {
+func TransformerBlock(k: Int, h: Int, hk: Int, b: Int, t: Int, usesFlashAttention: Bool) -> Model {
  let x = Input()
  let rot = Input()
  let y = Input()
  let adaLNs = (0..<6).map { Dense(count: k * h, name: "ada_ln_\($0)") }
  let chunks = adaLNs.map { $0(y) }
-  let attention = SelfAttention(prefix: "", k: k, h: h, hk: hk, b: b, t: t)
+  let attention = SelfAttention(
+    prefix: "", k: k, h: h, hk: hk, b: b, t: t, usesFlashAttention: usesFlashAttention)
  let attentionNorm = LayerNorm(epsilon: 1e-6, axis: [2], elementwiseAffine: false)
  var out = x + chunks[2] .* attention(attentionNorm(x) .* (1 + chunks[1]) + chunks[0], rot)
  let ffn = FeedForward(hiddenSize: k * h, intermediateSize: k * h * 3)
@@ -91,7 +103,7 @@ func TransformerBlock(k: Int, h: Int, hk: Int, b: Int, t: Int) -> Model {
  return Model([x, rot, y], [out])
}

-func DiT(batchSize: Int, hiddenSize: Int, layers: Int) -> Model {
+func DiT(batchSize: Int, hiddenSize: Int, layers: Int, usesFlashAttention: Bool) -> Model {
  let x = Input()
  let conv0 = Convolution(
    groups: 1, filters: hiddenSize / 2, filterSize: [5, 5],
@@ -116,7 +128,8 @@ func DiT(batchSize: Int, hiddenSize: Int, layers: Int) -> Model {
  let adaLNInput = (timestepEmbedder(t) + labelEmbedder(y)).reshaped([batchSize, 1, hiddenSize])
    .swish()
  for _ in 0..<layers {
-    let transformer = TransformerBlock(k: hiddenSize / 8, h: 8, hk: 8, b: batchSize, t: 256)
+    let transformer = TransformerBlock(
+      k: hiddenSize / 8, h: 8, hk: 8, b: batchSize, t: 256, usesFlashAttention: usesFlashAttention)
    out = transformer(out, rot, adaLNInput)
  }
  let norm = LayerNorm(epsilon: 1e-6, axis: [2], elementwiseAffine: false)
@@ -193,7 +206,7 @@ if deviceCount > 1 {
let summaryWriter = SummaryWriter(logDirectory: "/tmp/minrf")

let graph = DynamicGraph()
-let dit = DiT(batchSize: batchSize, hiddenSize: 256, layers: 10)
+let dit = DiT(batchSize: batchSize, hiddenSize: 256, layers: 10, usesFlashAttention: true)
var rot = graph.variable(.CPU, .NCHW(batchSize, 16 * 16, 8, 32), of: Float.self)
for i in 0..<(16 * 16) {
  for k in 0..<16 {
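
The training script hard-codes usesFlashAttention: true. A hypothetical variant (an assumption, not part of this commit) would expose the switch on the command line so the reference softmax path stays reachable for debugging; the --no-flash-attn flag name below is made up:

// Hypothetical: gate the fused kernel behind a flag instead of hard-coding true.
let useFlash = !CommandLine.arguments.contains("--no-flash-attn")
let dit = DiT(batchSize: batchSize, hiddenSize: 256, layers: 10, usesFlashAttention: useFlash)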
2 changes: 2 additions & 0 deletions nnc/DynamicGraph.swift
@@ -516,6 +516,8 @@ extension DynamicGraph.AnyTensor: CustomDebugStringConvertible {
      let cmd = ccv_nnc_cmd(
        CCV_NNC_DATA_TRANSFER_FORWARD, nil, CmdParamsFactory.factory.newParams(), 0)
      ccv_nnc_cmd_exec(cmd, ccv_nnc_no_hint, 0, &_input, 1, &_output, 1, _streamContext)
+      // Need to wait the stream to be done so we can print current ones.
+      ccv_nnc_stream_context_wait(_streamContext)
    }
    defer {
      if _output != nil {
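
The transfer issued by ccv_nnc_cmd_exec runs asynchronously on _streamContext, so without the added wait debugDescription could format host memory the copy has not reached yet. A rough usage sketch of what the fix makes reliable (assuming a GPU build; .GPU(0) and randn(std:mean:) are used here as in the repository's examples):

let graph = DynamicGraph()
let x = graph.variable(.GPU(0), .NC(4, 4), of: Float.self)
x.randn(std: 1, mean: 0)  // fills the tensor asynchronously on the GPU stream
// debugPrint goes through debugDescription, which copies the tensor to the CPU on the
// same stream; ccv_nnc_stream_context_wait ensures that copy finished before formatting.
debugPrint(x)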
