Skip to content

Commit

Permalink
Make the API more backward compatible.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jul 22, 2024
1 parent 5568fc3 commit 1ae29f1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
32 changes: 31 additions & 1 deletion nnc/Store.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4499,7 +4499,7 @@ extension DynamicGraph {
* - reader: You can customize your reader to load parameter with a different name etc.
*/
public func read(
_ key: String, model: Model, strict: Bool = false, codec: Codec = [],
_ key: String, model: Model, strict: Bool, codec: Codec = [],
reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil
) throws {
guard let reader = reader else {
Expand Down Expand Up @@ -4564,6 +4564,21 @@ extension DynamicGraph {
cString: ccv_cnnp_model_parameter_name(model.cModel, _io)))
}
}
/**
* Read parameters into a given model.
*
* - Parameters:
* - key: The key corresponding to a particular model.
* - model: The model to be initialized with parameters from a given key.
* - codec: The codec for potential encoded parameters.
* - reader: You can customize your reader to load parameter with a different name etc.
*/
public func read(
_ key: String, model: Model, codec: Codec = [],
reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil
) {
try? read(key, model: model, strict: false, codec: codec, reader: reader)
}
/**
* Read parameters into a given model builder.
*
Expand All @@ -4580,6 +4595,21 @@ extension DynamicGraph {
) throws {
try model.read(key, from: store, strict: strict, codec: codec, reader: reader)
}
/**
* Read parameters into a given model builder.
*
* - Parameters:
* - key: The key corresponding to a particular model.
* - model: The model builder to be initialized with parameters from a given key.
* - codec: The codec for potential encoded parameters.
* - reader: You can customize your reader to load parameter with a different name etc.
*/
public func read(
_ key: String, model: AnyModelBuilder, codec: Codec = [],
reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil
) {
try? read(key, model: model, strict: false, codec: codec, reader: reader)
}

/**
* Write a tensor to the store.
Expand Down
6 changes: 3 additions & 3 deletions test/store.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ final class StoreTests: XCTestCase {
linear1.compile(inputs: tv0)
graph.openStore("test/model.db") { store in
store.write("a", model: linear0)
try! store.read("a", model: linear1) { name, _, _, _ in
store.read("a", model: linear1) { name, _, _, _ in
return .continue("__a__[t-linear-0-0]")
}
}
Expand All @@ -157,7 +157,7 @@ final class StoreTests: XCTestCase {
linear1.compile(inputs: tv0)
graph.openStore("test/model.db") { store in
store.write("a", model: linear0)
try! store.read("a", model: linear1) { name, _, format, shape in
store.read("a", model: linear1) { name, _, format, shape in
var a = Tensor<Float32>(.CPU, format: format, shape: shape)
a[0, 0] = 2
return .final(a)
Expand All @@ -178,7 +178,7 @@ final class StoreTests: XCTestCase {
store.write("a", model: linear0) { name, _ in
return .continue("__a__[t-0-0]")
}
try! store.read("a", model: linear1)
store.read("a", model: linear1)
}
let tv2 = linear1(inputs: tv0)[0].as(of: Float32.self)
XCTAssertEqual(tv1[0], tv2[0])
Expand Down

0 comments on commit 1ae29f1

Please sign in to comment.