Skip to content

Commit

Permalink
Add gradientCheckpointing flag. Need to do more testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Mar 20, 2024
1 parent 282aeb4 commit c2ab0fe
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
name = "ccv",
commit = "591e0ebba8b745e6665d5161c50ba09bc5b91394",
commit = "5468f9ed64b4a52ee99844a17a89888084ee8478",
remote = "https://github.com/liuliu/ccv.git",
shallow_since = "1710563054 -0400",
shallow_since = "1710956371 -0400",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
Expand Down
4 changes: 2 additions & 2 deletions deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def s4nnc_deps():
git_repository,
name = "ccv",
remote = "https://github.com/liuliu/ccv.git",
commit = "591e0ebba8b745e6665d5161c50ba09bc5b91394",
shallow_since = "1710563054 -0400",
commit = "5468f9ed64b4a52ee99844a17a89888084ee8478",
shallow_since = "1710956371 -0400",
)

_maybe(
Expand Down
9 changes: 9 additions & 0 deletions nnc/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ public class Model {
return trainable >= 0 ? trainable != 0 : nil
}

public var gradientCheckpointing: Bool = false {
didSet {
ccv_cnnp_model_set_gradient_checkpointing(cModel, gradientCheckpointing ? 1 : 0)
}
}

/**
* Specify the maximum number of streams we need to allocate to run this model.
*/
public var maxConcurrency: StreamContext.Concurrency = .noLimit {
didSet {
ccv_cnnp_model_set_max_concurrency(cModel, Int32(maxConcurrency.rawValue))
Expand Down

0 comments on commit c2ab0fe

Please sign in to comment.