Skip to content

Commit

Permalink
Update to expose these two properties better.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Mar 21, 2024
1 parent c2ab0fe commit e5f9c06
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
name = "ccv",
commit = "5468f9ed64b4a52ee99844a17a89888084ee8478",
commit = "a77e2b18989d009776282b33c3dd16fb968ee346",
remote = "https://github.com/liuliu/ccv.git",
shallow_since = "1710956371 -0400",
shallow_since = "1710999193 -0400",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
Expand Down
4 changes: 2 additions & 2 deletions deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def s4nnc_deps():
git_repository,
name = "ccv",
remote = "https://github.com/liuliu/ccv.git",
commit = "5468f9ed64b4a52ee99844a17a89888084ee8478",
shallow_since = "1710956371 -0400",
commit = "a77e2b18989d009776282b33c3dd16fb968ee346",
shallow_since = "1710999193 -0400",
)

_maybe(
Expand Down
21 changes: 19 additions & 2 deletions nnc/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,26 @@ public class Model {
return trainable >= 0 ? trainable != 0 : nil
}

public var gradientCheckpointing: Bool = false {
/**
* Whether to enable gradient checkpointing for this model. Once it is enabled, we will re-run
* the model forward pass again during backward pass. This is effective at reducing memory usage.
*/
public var gradientCheckpointing: Bool {
get {
ccv_cnnp_model_gradient_checkpointing(cModel) != 0
}
set {
ccv_cnnp_model_set_gradient_checkpointing(cModel, newValue ? 1 : 0)
}
}

/**
* Whether to enable memory reduction for this model. The current supported memory reduction
* technique is to redo datatype conversion during backward pass if needed.
*/
public var memoryReduction: Bool = false {
didSet {
ccv_cnnp_model_set_gradient_checkpointing(cModel, gradientCheckpointing ? 1 : 0)
ccv_cnnp_model_set_memory_reduction(cModel, memoryReduction ? 1 : 0)
}
}

Expand Down

0 comments on commit e5f9c06

Please sign in to comment.