Skip to content

Commit

Permalink
Add chunked op, making AdaLN easier to implement.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Mar 16, 2024
1 parent 04fbc84 commit 282aeb4
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 4 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
name = "ccv",
commit = "bd451f3621ba451fe7bb16a0afa767267994943f",
commit = "591e0ebba8b745e6665d5161c50ba09bc5b91394",
remote = "https://github.com/liuliu/ccv.git",
shallow_since = "1710546673 -0400",
shallow_since = "1710563054 -0400",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
Expand Down
4 changes: 2 additions & 2 deletions deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def s4nnc_deps():
git_repository,
name = "ccv",
remote = "https://github.com/liuliu/ccv.git",
commit = "bd451f3621ba451fe7bb16a0afa767267994943f",
shallow_since = "1710546673 -0400",
commit = "591e0ebba8b745e6665d5161c50ba09bc5b91394",
shallow_since = "1710563054 -0400",
)

_maybe(
Expand Down
38 changes: 38 additions & 0 deletions nnc/FunctionalAddons.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1759,3 +1759,41 @@ extension DynamicGraph.AnyTensor {
return isNaN != 0
}
}

extension DynamicGraph.Tensor {
public func chunked(_ numberOfChunks: Int, axis: Int, streamContext: StreamContext?)
-> [DynamicGraph.Tensor<Element>]
{
var shape = shape
precondition(axis < shape.count)
precondition((shape[axis] % numberOfChunks) == 0)
shape[axis] = shape[axis] / numberOfChunks
var offset = TensorShape([])
let strides = strides
let format = format
return (0..<numberOfChunks).map {
offset[axis] = shape[axis] * $0
return reshaped(format: format, shape: shape, offset: offset, strides: strides)
}
}
}

extension DynamicGraph.Group where Element: DynamicGraph.AnyTensor {
public func chunked(_ numberOfChunks: Int, axis: Int, streamContext: StreamContext?) -> [Self] {
var shape = shape
precondition(axis < shape.count)
precondition((shape[axis] % numberOfChunks) == 0)
shape[axis] = shape[axis] / numberOfChunks
let result = underlyingArray.map { tensor in
var offset = TensorShape([])
let strides = tensor.strides
return (0..<numberOfChunks).map {
offset[axis] = shape[axis] * $0
return tensor.reshaped(format: format, shape: shape, offset: offset, strides: strides)
}
}
return (0..<numberOfChunks).map { index in
DynamicGraph.Group(result.map { $0[index] })
}
}
}
26 changes: 26 additions & 0 deletions nnc/ModelAddons.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,32 @@ extension Functional {
public static func concat(axis: Int, _ inputs: ModelIOConvertible...) -> Model.IO {
return Concat(axis: axis).apply(inputs)
}

public static func concat<T: DynamicGraph.TensorGroup>(
axis: Int, _ inputs: T..., streamContext: StreamContext? = nil
) -> T {
let outputs = Concat(axis: axis)(
inputs: inputs[0], Array(inputs.suffix(from: 1)), streamContext: streamContext)
return T(outputs[0])
}
}

/// Chunk model.
public final class Chunk: Model {
required init(_ model: OpaquePointer) {
super.init(model)
}

public init(_ numberOfChunks: Int, axis: Int, name: String = "") {
super.init(ccv_cnnp_chunk(Int32(numberOfChunks), Int32(axis), name))
}
}

extension ModelIOConvertible {
public func chunked(_ numberOfChunks: Int, axis: Int) -> [Model.IO] {
let result = Chunk(numberOfChunks, axis: axis)(self)
return (0..<numberOfChunks).map { result[$0] }
}
}

/// LSTM model.
Expand Down
9 changes: 9 additions & 0 deletions nnc/TensorGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ public protocol DynamicGraph_TensorGroup: DynamicGraph_AnyTensorGroup {
func tanh(streamContext: StreamContext?)
/// Apply swish activation to the given tensor inplace.
func swish(streamContext: StreamContext?)
/// Chunk the current tensor into multiple ones.
func chunked(_ numberOfChunks: Int, axis: Int, streamContext: StreamContext?) -> [Self]
}

extension DynamicGraph_TensorGroup {
Expand Down Expand Up @@ -349,6 +351,13 @@ extension DynamicGraph_TensorGroup {
public func swish(streamContext: StreamContext? = nil) {
swish(streamContext: streamContext)
}
/// Chunk the current tensor into multiple ones.
@inlinable
public func chunked(_ numberOfChunks: Int, axis: Int = 0, streamContext: StreamContext? = nil)
-> [Self]
{
chunked(numberOfChunks, axis: axis, streamContext: streamContext)
}
}

extension DynamicGraph {
Expand Down
28 changes: 28 additions & 0 deletions test/ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@ final class OpsTests: XCTestCase {
XCTAssertEqual(a4.rawValue[1, 1], 4.4 - 5)
}

func testOpChunked() throws {
let dynamicGraph = DynamicGraph()
let a0 = dynamicGraph.variable(Tensor<Float32>([1.1, 2.2, 3.3, 4.4], .CPU, .NC(2, 2)))
let b = a0.chunked(2)
XCTAssertEqual(b[0].rawValue[0, 0], 1.1)
XCTAssertEqual(b[0].rawValue[0, 1], 2.2)
XCTAssertEqual(b[1].rawValue[0, 0], 3.3)
XCTAssertEqual(b[1].rawValue[0, 1], 4.4)
}

func testOpChunkedGroup() throws {
let dynamicGraph = DynamicGraph()
let a0 = dynamicGraph.variable(Tensor<Float32>([1.1, 2.2, 3.3, 4.4], .CPU, .NC(2, 2)))
let a1 = dynamicGraph.variable(Tensor<Float32>([1.2, 2.4, 3.6, 4.8], .CPU, .NC(2, 2)))
let a = DynamicGraph.Group(a0, a1)
let b = a.chunked(2)
XCTAssertEqual(b[0][0].rawValue[0, 0], 1.1)
XCTAssertEqual(b[0][0].rawValue[0, 1], 2.2)
XCTAssertEqual(b[1][0].rawValue[0, 0], 3.3)
XCTAssertEqual(b[1][0].rawValue[0, 1], 4.4)
XCTAssertEqual(b[0][1].rawValue[0, 0], 1.2)
XCTAssertEqual(b[0][1].rawValue[0, 1], 2.4)
XCTAssertEqual(b[1][1].rawValue[0, 0], 3.6)
XCTAssertEqual(b[1][1].rawValue[0, 1], 4.8)
}

func testReduceSumModel() throws {
let dynamicGraph = DynamicGraph()
let input = Input()
Expand Down Expand Up @@ -169,6 +195,8 @@ final class OpsTests: XCTestCase {
("testReduceMean", testReduceMean),
("testReduceMax", testReduceMax),
("testOpAdd", testOpAdd),
("testOpChunked", testOpChunked),
("testOpChunkedGroup", testOpChunkedGroup),
("testReduceSumModel", testReduceSumModel),
("testReduceMeanModel", testReduceMeanModel),
("testReduceMaxModel", testReduceMaxModel),
Expand Down

0 comments on commit 282aeb4

Please sign in to comment.