Skip to content

Commit

Permalink
Make some more things public (#91)
Browse files Browse the repository at this point in the history
Also improves coefficient encoding/decoding.
Also moves decoding to Scheme.{Signed}Scalar.
  • Loading branch information
fboemer authored Sep 5, 2024
1 parent 562dd5f commit 5052bab
Show file tree
Hide file tree
Showing 15 changed files with 138 additions and 101 deletions.
54 changes: 39 additions & 15 deletions Sources/HomomorphicEncryption/Array2d.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ public struct Array2d<T: Equatable & AdditiveArithmetic & Sendable>: Equatable,
@usableFromInline package var rowCount: Int
@usableFromInline package var columnCount: Int

@usableFromInline package var shape: (Int, Int) { (rowCount, columnCount) }
@usableFromInline package var count: Int { rowCount * columnCount }
/// The row and column counts.
public var shape: (rowCount: Int, columnCount: Int) { (rowCount: rowCount, columnCount: columnCount) }
/// The number of entries in the array.
public var count: Int { rowCount * columnCount }

/// Creates a new ``Array2d``.
/// - Parameter data: Row-major entries of the array. Each row must have the same number of entries.
@inlinable
package init(data: [[T]] = []) {
public init(data: [[T]] = []) {
if data.isEmpty {
self.init(data: [], rowCount: 0, columnCount: 0)
} else {
Expand All @@ -36,23 +40,36 @@ public struct Array2d<T: Equatable & AdditiveArithmetic & Sendable>: Equatable,
}
}

/// Creates a new ``Array2d``.
/// - Parameters:
/// - data: Row-major entries of the array. Must have `rowCount * columnCount` entries.
/// - rowCount: Number of rows. Must be non-negative.
/// - columnCount: Number of columns. Must be non-negative.
@inlinable
package init(data: [T], rowCount: Int, columnCount: Int) {
precondition(data.count == rowCount * columnCount)
public init(data: [T], rowCount: Int, columnCount: Int) {
precondition(data.count == rowCount * columnCount, "Wrong data count \(data.count)")
precondition(rowCount >= 0 && columnCount >= 0)
self.data = data
self.rowCount = rowCount
self.columnCount = columnCount
}

/// Creates a new ``Array2d`` from an existing array.
/// - Parameter array: Existing array; must have entry type representable by `T`.
@inlinable
init(array: Array2d<some FixedWidthInteger>) where T: FixedWidthInteger {
public init(array: Array2d<some FixedWidthInteger>) where T: FixedWidthInteger {
self.columnCount = array.columnCount
self.rowCount = array.rowCount
self.data = array.data.map { T($0) }
}

/// Creates a new array of zeros.
/// - Parameters:
/// - rowCount: Number of rows.
/// - columnCount: Number of columns.
/// - Returns: The array of zeros.
@inlinable
package static func zero(rowCount: Int, columnCount: Int) -> Self {
public static func zero(rowCount: Int, columnCount: Int) -> Self {
self.init(
data: [T](Array(repeating: T.zero, count: rowCount * columnCount)),
rowCount: rowCount,
Expand All @@ -62,30 +79,33 @@ public struct Array2d<T: Equatable & AdditiveArithmetic & Sendable>: Equatable,

extension Array2d {
@inlinable
package func index(row: Int, column: Int) -> Int {
func index(row: Int, column: Int) -> Int {
row &* columnCount &+ column
}

@inlinable
package func rowIndices(row: Int) -> Range<Int> {
func rowIndices(row: Int) -> Range<Int> {
index(row: row, column: 0)..<index(row: row, column: columnCount)
}

@inlinable
package func columnIndices(column: Int) -> StrideTo<Int> {
func columnIndices(column: Int) -> StrideTo<Int> {
stride(from: index(row: 0, column: column), to: index(row: rowCount, column: column), by: columnCount)
}

/// Returns the entries in the row.
/// - Parameter row: Index of the row. Must be in `[0, rowCount)`.
/// - Returns: The entries in the row.
@inlinable
package func row(row: Int) -> [T] {
public func row(_ row: Int) -> [T] {
Array(data[rowIndices(row: row)])
}

/// Gathers array values into an array.
/// - Parameter indices: Indices whose values to gather.
/// - Returns: The values of the array in order of the given indices.
@inlinable
public func collectValues(indices: any Sequence<Int>) -> [T] {
func collectValues(indices: any Sequence<Int>) -> [T] {
indices.map { data[$0] }
}

Expand All @@ -106,7 +126,7 @@ extension Array2d {
}

@inlinable
package subscript(_ index: Int) -> T {
subscript(_ index: Int) -> T {
get {
data[index]
}
Expand All @@ -115,8 +135,12 @@ extension Array2d {
}
}

/// Access for the `(row, column)` entry.
/// - Parameters:
/// - `row`: Must be in `[0, rowCount)`
/// - `column`: Must be in `[0, columnCount)`
@inlinable
package subscript(_ row: Int, _ column: Int) -> T {
public subscript(_ row: Int, _ column: Int) -> T {
get {
data[index(row: row, column: column)]
}
Expand Down Expand Up @@ -196,7 +220,7 @@ extension Array2d {
/// returns a transformed value of the same or of a different type.
/// - Returns: The transformed matrix.
@inlinable
package func map<V: Equatable & AdditiveArithmetic & Sendable>(_ transform: (T) -> (V)) -> Array2d<V> {
public func map<V: Equatable & AdditiveArithmetic & Sendable>(_ transform: (T) -> (V)) -> Array2d<V> {
Array2d<V>(
data: data.map { value in transform(value) },
rowCount: rowCount,
Expand Down
12 changes: 6 additions & 6 deletions Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ extension Bfv {

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, signedValues: some Collection<Scalar.SignedScalar>,
public static func encode(context: Context<Bfv<T>>, signedValues: some Collection<SignedScalar>,
format: EncodeFormat) throws -> CoeffPlaintext
{
try context.encode(signedValues: signedValues, format: format)
Expand All @@ -53,7 +53,7 @@ extension Bfv {
// swiftlint:disable:next missing_docs attributes
public static func encode(
context: Context<Bfv<T>>,
signedValues: some Collection<Scalar.SignedScalar>,
signedValues: some Collection<SignedScalar>,
format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext
{
Expand All @@ -63,25 +63,25 @@ extension Bfv {

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decodeCoeff<V: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [Scalar] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decodeCoeff<V: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [SignedScalar] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decodeEval<V: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [Scalar] {
try plaintext.convertToCoeffFormat().decode(format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decodeEval<V: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [SignedScalar] {
try plaintext.convertToCoeffFormat().decode(format: format)
}
}
35 changes: 15 additions & 20 deletions Sources/HomomorphicEncryption/Encoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ extension Context {
public func encode(signedValues: some Collection<Scheme.SignedScalar>,
format: EncodeFormat) throws -> Plaintext<Scheme, Coeff>
{
let signedModulus = Scheme.Scalar.SignedScalar(plaintextModulus)
let signedModulus = Scheme.SignedScalar(plaintextModulus)
let bounds = -(signedModulus >> 1)...((signedModulus - 1) >> 1)
let centeredValues = try signedValues.map { value in
guard bounds.contains(Scheme.Scalar.SignedScalar(value)) else {
guard bounds.contains(Scheme.SignedScalar(value)) else {
throw HeError.encodingDataOutOfBounds(for: bounds)
}
return Scheme.Scalar(value.centeredToRemainder(modulus: plaintextModulus))
Expand Down Expand Up @@ -102,9 +102,7 @@ extension Context {
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode.
@inlinable
func decode<T: ScalarType>(plaintext: Plaintext<Scheme, Coeff>,
format: EncodeFormat) throws -> [T]
{
func decode(plaintext: Plaintext<Scheme, Coeff>, format: EncodeFormat) throws -> [Scheme.Scalar] {
switch format {
case .coefficient:
return decodeCoefficient(plaintext: plaintext)
Expand All @@ -121,10 +119,10 @@ extension Context {
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode.
@inlinable
func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Coeff>, format: EncodeFormat) throws -> [T] {
func decode(plaintext: Plaintext<Scheme, Coeff>, format: EncodeFormat) throws -> [Scheme.SignedScalar] {
let unsignedValues: [Scheme.Scalar] = try decode(plaintext: plaintext, format: format)
return unsignedValues.map { value in
T(value.remainderToCentered(modulus: plaintextModulus))
value.remainderToCentered(modulus: plaintextModulus)
}
}

Expand All @@ -136,7 +134,7 @@ extension Context {
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode.
@inlinable
func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Eval>, format: EncodeFormat) throws -> [T] {
func decode(plaintext: Plaintext<Scheme, Eval>, format: EncodeFormat) throws -> [Scheme.SignedScalar] {
try Scheme.decodeEval(plaintext: plaintext, format: format)
}

Expand Down Expand Up @@ -167,12 +165,11 @@ extension Context {
if values.isEmpty {
return Plaintext<Scheme, Coeff>(context: self, poly: PolyRq.zero(context: plaintextContext))
}
let polyDegree = plaintextContext.degree
var array: Array2d<Scheme.Scalar> = Array2d(array: Array2d(
data: Array(values),
rowCount: 1,
columnCount: values.count))
array.resizeColumn(newColumnCount: polyDegree, defaultValue: Scheme.Scalar(0))
var valuesArray = Array(values)
if valuesArray.count < degree {
valuesArray.append(contentsOf: repeatElement(0, count: degree - valuesArray.count))
}
let array: Array2d<Scheme.Scalar> = Array2d(data: valuesArray, rowCount: 1, columnCount: valuesArray.count)
return Plaintext<Scheme, Coeff>(
context: self,
poly: PolyRq(context: plaintextContext, data: array))
Expand All @@ -185,10 +182,8 @@ extension Context {
/// - Parameter plaintext: Plaintext to decode.
/// - Returns: The decoded plaintext values, each in `[0, t - 1]` for plaintext modulus `t`.
@inlinable
func decodeCoefficient<T: ScalarType>(plaintext: Plaintext<Scheme, Coeff>)
-> [T]
{
Array2d(array: plaintext.poly.data).data
func decodeCoefficient(plaintext: Plaintext<Scheme, Coeff>) -> [Scheme.Scalar] {
plaintext.poly.data.data
}
}

Expand Down Expand Up @@ -233,13 +228,13 @@ extension Context {
}

@inlinable
func decodeSimd<T: ScalarType>(plaintext: Plaintext<Scheme, Coeff>) throws -> [T] {
func decodeSimd(plaintext: Plaintext<Scheme, Coeff>) throws -> [Scheme.Scalar] {
guard !simdEncodingMatrix.isEmpty else {
throw HeError.simdEncodingNotSupported(for: encryptionParameters)
}
let poly = try plaintext.poly.forwardNtt()
return (0..<encryptionParameters.polyDegree).map { index in
T(poly.data[0, simdEncodingMatrix[index]])
poly.data[0, simdEncodingMatrix[index]]
}
}
}
16 changes: 8 additions & 8 deletions Sources/HomomorphicEncryption/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ public protocol HeScheme {
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API.
static func decodeCoeff<T: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]
static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [Scalar]

/// Decodes a plaintext in ``Coeff`` format into signed values.
/// - Parameters:
Expand All @@ -238,7 +238,7 @@ public protocol HeScheme {
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API.
static func decodeCoeff<T: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]
static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [SignedScalar]

/// Decodes a plaintext in ``Eval`` format.
/// - Parameters:
Expand All @@ -247,7 +247,7 @@ public protocol HeScheme {
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API.
static func decodeEval<T: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]
static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [Scalar]

/// Decodes a plaintext in ``Eval`` format to signed values.
/// - Parameters:
Expand All @@ -256,7 +256,7 @@ public protocol HeScheme {
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API.
static func decodeEval<T: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]
static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [SignedScalar]

/// Symmetric secret key encryption of a plaintext.
/// - Parameters:
Expand Down Expand Up @@ -933,9 +933,9 @@ extension HeScheme {
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API.
@inlinable
public static func decode<T: ScalarType, Format: PolyFormat>(
public static func decode<Format: PolyFormat>(
plaintext: Plaintext<Self, Format>,
format: EncodeFormat) throws -> [T]
format: EncodeFormat) throws -> [Scalar]
{
if Format.self == Coeff.self {
// swiftlint:disable:next force_cast
Expand All @@ -958,9 +958,9 @@ extension HeScheme {
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API.
@inlinable
public static func decode<T: SignedScalarType, Format: PolyFormat>(
public static func decode<Format: PolyFormat>(
plaintext: Plaintext<Self, Format>,
format: EncodeFormat) throws -> [T]
format: EncodeFormat) throws -> [SignedScalar]
{
if Format.self == Coeff.self {
// swiftlint:disable:next force_cast
Expand Down
8 changes: 4 additions & 4 deletions Sources/HomomorphicEncryption/NoOpScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,19 @@ public enum NoOpScheme: HeScheme {
return try EvalPlaintext(context: context, poly: coeffPlaintext.poly.forwardNtt())
}

public static func decodeCoeff<T: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] {
public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [Scalar] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

public static func decodeEval<T: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] {
public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [Scalar] {
try plaintext.inverseNtt().decode(format: format)
}

public static func decodeCoeff<T: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] {
public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [SignedScalar] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

public static func decodeEval<T: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] {
public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [SignedScalar] {
try plaintext.inverseNtt().decode(format: format)
}

Expand Down
Loading

0 comments on commit 5052bab

Please sign in to comment.