Skip to content

Commit

Permalink
Support / op in Model. This should have all facilities for Mistral mo…
Browse files Browse the repository at this point in the history
…del?
  • Loading branch information
liuliu committed Dec 15, 2023
1 parent 9a9b0ef commit ab8a94d
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 7 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
name = "ccv",
commit = "8e4a5a05201c1cf9d966a283c014b182c14368b5",
commit = "a916bbf1b7d05f3d1573bd2f9a0108c07c098a7c",
remote = "https://github.com/liuliu/ccv.git",
shallow_since = "1702616811 -0500",
shallow_since = "1702668857 -0500",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
Expand Down
4 changes: 2 additions & 2 deletions deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def s4nnc_deps():
git_repository,
name = "ccv",
remote = "https://github.com/liuliu/ccv.git",
commit = "8e4a5a05201c1cf9d966a283c014b182c14368b5",
shallow_since = "1702616811 -0500",
commit = "a916bbf1b7d05f3d1573bd2f9a0108c07c098a7c",
shallow_since = "1702668857 -0500",
)

_maybe(
Expand Down
27 changes: 27 additions & 0 deletions nnc/ModelAddons.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,33 @@ public final class Mul: Model {
}
}

/// Div two inputs together. It will not do broadcast.
public final class Div: Model {
required init(_ model: OpaquePointer) {
super.init(model)
}

public init(reciprocal: Bool = false, name: String = "") {
super.init(ccv_cnnp_div(reciprocal ? 1 : 0, name))
}

public func callAsFunction<T: DynamicGraph.TensorGroup>(
_ left: T, _ right: T, streamContext: StreamContext? = nil
) -> T {
let outputs = self(inputs: left, right, streamContext: streamContext)
return T(outputs[0])
}
}

extension ModelIOConvertible {
/**
* Compute the reciprocal for a model IO.
*/
public func reciprocal() -> Model.IO {
return Div(reciprocal: true)(self)
}
}

/// Matrix-multiplication over two inputs.
public final class Matmul: Model {
required init(_ model: OpaquePointer) {
Expand Down
26 changes: 23 additions & 3 deletions nnc/Operators.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
infix operator .*: MultiplicationPrecedence
infix operator ./: MultiplicationPrecedence
infix operator .+: AdditionPrecedence
infix operator .* : MultiplicationPrecedence
infix operator ./ : MultiplicationPrecedence
infix operator .+ : AdditionPrecedence

// Element-wise addition
public func .+ <T: DynamicGraph.TensorGroup>(left: T, right: T) -> T {
Expand All @@ -16,6 +16,10 @@ public func ./ <T: DynamicGraph.TensorGroup>(left: T, right: T) -> T {
return Functional.div(left: left, right: right)
}

public func ./ (left: ModelIOConvertible, right: ModelIOConvertible) -> Model.IO {
return Div(reciprocal: false)(left, right)
}

// Broadcast element-wise multiplication
public func .* <T: DynamicGraph.TensorGroup>(left: T, right: T) -> T {
return Functional.mul(left: left, right: right)
Expand Down Expand Up @@ -84,6 +88,22 @@ public func + (left: ModelIOConvertible, right: Float) -> Model.IO {
return Add()(left, Scalar(value: right)(left))
}

public func / (left: Float, right: ModelIOConvertible) -> Model.IO {
if left == 1 {
return Div(reciprocal: true)(right)
} else {
return left * Div(reciprocal: true)(right)
}
}

public func / (left: ModelIOConvertible, right: Float) -> Model.IO {
if right == 1 {
return left.io
} else {
return Scalmul(1.0 / right)(left)
}
}

// Broadcast element-wise subtraction.
public func - <T: DynamicGraph.TensorGroup>(left: T, right: T) -> T {
return Functional.add(left: left, right: right, rightScalar: -1)
Expand Down
22 changes: 22 additions & 0 deletions test/model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,27 @@ final class ModelTests: XCTestCase {
XCTAssertEqual(tv3.rawValue[0], 1.1 * 2.2 + 3.1, accuracy: 1e-5)
}

func testModelDiv() throws {
let dynamicGraph = DynamicGraph()

func DivRec() -> Model {
let i0 = Input()
let i1 = Input()
let i2 = i0 ./ i1
let i3 = Input()
let i4 = 0.5 / i3
return Model([i0, i1, i3], [i2, i4])
}

let div = DivRec()
let tv0 = dynamicGraph.variable(Tensor<Float32>([1.1], .CPU, .C(1)))
let tv1 = dynamicGraph.variable(Tensor<Float32>([2.2], .CPU, .C(1)))
let tv2 = dynamicGraph.variable(Tensor<Float32>([0.2], .CPU, .C(1)))
let tv3s = div(inputs: tv0, tv1, tv2).map { $0.as(of: Float32.self) }
XCTAssertEqual(tv3s[0].rawValue[0], 1.1 / 2.2, accuracy: 1e-5)
XCTAssertEqual(tv3s[1].rawValue[0], 0.5 / 0.2, accuracy: 1e-5)
}

func testModelScaledDotProductAttention() throws {
let dynamicGraph = DynamicGraph()
let q = dynamicGraph.variable(Tensor<Float32>(.CPU, .NHWC(1, 10, 8, 20)))
Expand Down Expand Up @@ -347,6 +368,7 @@ final class ModelTests: XCTestCase {
("testSequential", testSequential),
("testModelWithScalar", testModelWithScalar),
("testModelWithParameter", testModelWithParameter),
("testModelDiv", testModelDiv),
("testModelScaledDotProductAttention", testModelScaledDotProductAttention),
("testCustomModel", testCustomModel),
]
Expand Down

0 comments on commit ab8a94d

Please sign in to comment.