-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
genericize CGLS #53
genericize CGLS #53
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,34 @@ | ||
import Foundation | ||
|
||
/// A Euclidean vector space. | ||
public protocol EuclideanVectorSpace: Differentiable, VectorProtocol | ||
where Self.TangentVector == Self | ||
where Self.TangentVector == Self, Self.VectorSpaceScalar == Double | ||
{ | ||
// Note: This is a work in progress. We intend to add more requirements here as we need them. | ||
|
||
/// The squared Euclidean norm of `self`. | ||
var squaredNorm: Double { get } | ||
} | ||
|
||
/// Convenient operators on Euclidean vector spaces. | ||
extension EuclideanVectorSpace { | ||
/// The Euclidean norm of `self`. | ||
public var norm: Double { | ||
return sqrt(squaredNorm) | ||
} | ||
|
||
// Note: We can't have these because Swift type inference is very inefficient | ||
// and these make it too slow. | ||
// | ||
// public static func * (_ lhs: Double, _ rhs: Self) -> Self { | ||
// return lhs.scaled(by: lhs) | ||
// } | ||
// | ||
// public static func * (_ lhs: Self, _ rhs: Double) -> Self { | ||
// return lhs.scaled(by: rhs) | ||
// } | ||
// | ||
// public static func / (_ lhs: Self, _ rhs: Double) -> Self { | ||
// return lhs.scaled(by: 1 / rhs) | ||
// } | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// Copyright 2020 The SwiftFusion Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
/// An affine function decomposed into its linear and bias components. | ||
public protocol DecomposedAffineFunction { | ||
associatedtype Input: EuclideanVectorSpace | ||
associatedtype Output: EuclideanVectorSpace | ||
|
||
/// Apply the function to `x`. | ||
/// | ||
/// This is equal to `applyLinearForward(x) + bias`. | ||
/// | ||
/// Note: A default implementation is provided, but conforming types may provide a more efficient | ||
/// implementation. | ||
func callAsFunction(_ x: Input) -> Output | ||
|
||
/// The linear component of the affine function. | ||
func applyLinearForward(_ x: Input) -> Output | ||
|
||
/// The linear adjoint of the linear component of the affine function. | ||
func applyLinearAdjoint(_ y: Output) -> Input | ||
|
||
/// The bias component of the affine function. | ||
/// | ||
/// This is equal to `applyLinearForward(Input.zero)`. | ||
var bias: Output { get } | ||
} | ||
|
||
extension DecomposedAffineFunction { | ||
public func callAsFunction(_ x: Input) -> Output { | ||
return applyLinearForward(x) + bias | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,11 +48,6 @@ public struct VectorValues: KeyPathIterable { | |
} | ||
} | ||
|
||
/// L2 norm of the VectorValues | ||
var norm: Double { | ||
self._values.map { $0.squared().sum() }.reduce(0.0, { $0 + $1 }) | ||
} | ||
|
||
/// Insert a key value pair | ||
public mutating func insert(_ key: Int, _ val: Vector) { | ||
assert(_indices[key] == nil) | ||
|
@@ -61,26 +56,20 @@ public struct VectorValues: KeyPathIterable { | |
self._values.append(val) | ||
} | ||
|
||
/// VectorValues + Scalar | ||
static func + (_ lhs: Self, _ rhs: Self.ScalarType) -> Self { | ||
var result = lhs | ||
let _ = result._values.indices.map { result._values[$0] += rhs } | ||
return result | ||
} | ||
|
||
/// Scalar * VectorValues | ||
static func * (_ lhs: Self.ScalarType, _ rhs: Self) -> Self { | ||
var result = rhs | ||
let _ = result._values.indices.map { result._values[$0] *= lhs } | ||
return result | ||
} | ||
} | ||
|
||
extension VectorValues: Differentiable { | ||
extension VectorValues: EuclideanVectorSpace { | ||
|
||
// NOTE: Most of these are boilerplate that should be synthesized automatically. However, the | ||
// current synthesis functionality can't deal with the `_indices` property. So we have to | ||
// implement it manually for now. | ||
|
||
// MARK: - Differentiable conformance. | ||
|
||
public typealias TangentVector = Self | ||
} | ||
|
||
extension VectorValues: AdditiveArithmetic { | ||
// MARK: - AdditiveArithmetic conformance. | ||
|
||
public static func += (_ lhs: inout VectorValues, _ rhs: VectorValues) { | ||
for key in rhs.keys { | ||
let rhsVector = rhs[key] | ||
|
@@ -110,6 +99,53 @@ extension VectorValues: AdditiveArithmetic { | |
public static var zero: VectorValues { | ||
return VectorValues() | ||
} | ||
|
||
// MARK: - VectorProtocol conformance | ||
|
||
public typealias VectorValuesSpaceScalar = Double | ||
|
||
public mutating func add(_ x: Double) { | ||
for index in _values.indices { | ||
_values[index] += x | ||
} | ||
} | ||
|
||
public func adding(_ x: Double) -> VectorValues { | ||
var result = self | ||
result.add(x) | ||
return result | ||
} | ||
|
||
public mutating func subtract(_ x: Double) { | ||
for index in _values.indices { | ||
_values[index] -= x | ||
} | ||
} | ||
|
||
public func subtracting(_ x: Double) -> VectorValues { | ||
var result = self | ||
result.subtract(x) | ||
return result | ||
} | ||
|
||
public mutating func scale(by scalar: Double) { | ||
for index in _values.indices { | ||
_values[index] *= scalar | ||
} | ||
} | ||
|
||
public func scaled(by scalar: Double) -> VectorValues { | ||
var result = self | ||
result.scale(by: scalar) | ||
return result | ||
} | ||
|
||
// MARK: - Additional EuclideanVectorSpace requirements. | ||
|
||
public var squaredNorm: Double { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This function currently isn't declared I like to avoid adding the |
||
self._values.map { $0.squared().sum() }.reduce(0.0, { $0 + $1 }) | ||
} | ||
|
||
} | ||
|
||
extension VectorValues: CustomStringConvertible { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain briefly how these are automatically implemented? Curious on the workings...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah! We have automatic synthesis for
AdditiveArithmetic
andVectorProtocol
that looks at all the fields of the struct. If all the fields conform toAdditiveArithmetic
orVectorProtocol
, it automatically implements the requirements by applying the functions to all the members.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The single field
Array<Error>.DifferentiableView
conforms toAdditveArithmetic
andVectorProtocol
because:Error
(akaVector
) conforms to them.Array.DifferentiableView
conforms toAdditiveArithmetic
when itsElement
does: https://github.com/apple/swift/blob/4bc72aedb2b32c33ed8e2ec241615fc890b60002/stdlib/public/Differentiation/ArrayDifferentiation.swift#L98Array.DifferentiableView
conforms toVectorProtocol
when itsElement
does: https://github.com/tensorflow/swift-apis/blob/a8a24c46e478ce50c1c9a7718a41eec453e5b670/Sources/TensorFlow/StdlibExtensions.swift#L224