From 1ae29f152d65f6f1254dbec0a55ceb13d78e98f9 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Mon, 22 Jul 2024 18:06:31 -0400 Subject: [PATCH] Make the API more backward compatible. --- nnc/Store.swift | 32 +++++++++++++++++++++++++++++++- test/store.swift | 6 +++--- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/nnc/Store.swift b/nnc/Store.swift index 2ecb7acd39a..f8028516d86 100644 --- a/nnc/Store.swift +++ b/nnc/Store.swift @@ -4499,7 +4499,7 @@ extension DynamicGraph { * - reader: You can customize your reader to load parameter with a different name etc. */ public func read( - _ key: String, model: Model, strict: Bool = false, codec: Codec = [], + _ key: String, model: Model, strict: Bool, codec: Codec = [], reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil ) throws { guard let reader = reader else { @@ -4564,6 +4564,21 @@ extension DynamicGraph { cString: ccv_cnnp_model_parameter_name(model.cModel, _io))) } } + /** + * Read parameters into a given model. + * + * - Parameters: + * - key: The key corresponding to a particular model. + * - model: The model to be initialized with parameters from a given key. + * - codec: The codec for potential encoded parameters. + * - reader: You can customize your reader to load parameter with a different name etc. + */ + public func read( + _ key: String, model: Model, codec: Codec = [], + reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil + ) { + try? read(key, model: model, strict: false, codec: codec, reader: reader) + } /** * Read parameters into a given model builder. * @@ -4580,6 +4595,21 @@ extension DynamicGraph { ) throws { try model.read(key, from: store, strict: strict, codec: codec, reader: reader) } + /** + * Read parameters into a given model builder. + * + * - Parameters: + * - key: The key corresponding to a particular model. + * - model: The model builder to be initialized with parameters from a given key. + * - codec: The codec for potential encoded parameters. + * - reader: You can customize your reader to load parameter with a different name etc. + */ + public func read( + _ key: String, model: AnyModelBuilder, codec: Codec = [], + reader: ((String, DataType, TensorFormat, TensorShape) -> ModelReaderResult)? = nil + ) { + try? read(key, model: model, strict: false, codec: codec, reader: reader) + } /** * Write a tensor to the store. diff --git a/test/store.swift b/test/store.swift index acc652d3bc3..6fb656ce193 100644 --- a/test/store.swift +++ b/test/store.swift @@ -140,7 +140,7 @@ final class StoreTests: XCTestCase { linear1.compile(inputs: tv0) graph.openStore("test/model.db") { store in store.write("a", model: linear0) - try! store.read("a", model: linear1) { name, _, _, _ in + store.read("a", model: linear1) { name, _, _, _ in return .continue("__a__[t-linear-0-0]") } } @@ -157,7 +157,7 @@ final class StoreTests: XCTestCase { linear1.compile(inputs: tv0) graph.openStore("test/model.db") { store in store.write("a", model: linear0) - try! store.read("a", model: linear1) { name, _, format, shape in + store.read("a", model: linear1) { name, _, format, shape in var a = Tensor(.CPU, format: format, shape: shape) a[0, 0] = 2 return .final(a) @@ -178,7 +178,7 @@ final class StoreTests: XCTestCase { store.write("a", model: linear0) { name, _ in return .continue("__a__[t-0-0]") } - try! store.read("a", model: linear1) + store.read("a", model: linear1) } let tv2 = linear1(inputs: tv0)[0].as(of: Float32.self) XCTAssertEqual(tv1[0], tv2[0])