Workaround for sr-13263 (#1)
* Update Normalization.swift
philipturner authored Jan 13, 2022
1 parent 652f815 commit 335f94d
Showing 1 changed file with 30 additions and 8 deletions.
Sources/TensorFlow/Layers/Normalization.swift (38 changes: 30 additions & 8 deletions)
@@ -105,21 +105,43 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     precondition(
       input.shape[positiveAxis] == offset.shape[0],
       "The number of features of the input and the offset doesn't match.")
-    var (offset, scale) = {x in (x.offset, x.scale) }(self)
-    if positiveAxis != input.rank - 1 {
-      var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
-      broadcastShape[positiveAxis] = input.shape[positiveAxis]
-      offset = offset.reshaped(to: broadcastShape)
-
-      scale = scale.reshaped(to: broadcastShape)
-    }
+    // var (offset, scale) = {x in (x.offset, x.scale) }(self)
+    // if positiveAxis != input.rank - 1 {
+    //   var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
+    //   broadcastShape[positiveAxis] = input.shape[positiveAxis]
+    //   offset = offset.reshaped(to: broadcastShape)
+    //   scale = scale.reshaped(to: broadcastShape)
+    // }
+    let offsetOriginal = self.offset
+    let scaleOriginal = self.scale
+    let (offset, scale) = Self._sr13263workaround(offset: offsetOriginal,
+                                                  scale: scaleOriginal,
+                                                  input: input,
+                                                  positiveAxis: positiveAxis)
     switch Context.local.learningPhase {
     case .training:
       return doTraining(input, offset: offset, scale: scale, axis: positiveAxis)
     case .inference:
       return doInference(input, offset: offset, scale: scale)
     }
   }
 
+  @inline(never)
+  @differentiable(reverse) // if the function is `public` or `internal`, the compiler crashes
+  private static func _sr13263workaround(
+    offset: Tensor<Scalar>,
+    scale: Tensor<Scalar>,
+    input: Tensor<Scalar>,
+    positiveAxis: Int
+  ) -> (Tensor<Scalar>, Tensor<Scalar>) {
+    if positiveAxis != input.rank - 1 {
+      var broadcastShape = TensorShape([Int](repeating: 1, count: input.rank))
+      broadcastShape[positiveAxis] = input.shape[positiveAxis]
+      return (offset.reshaped(to: broadcastShape), scale.reshaped(to: broadcastShape))
+    } else {
+      return (offset, scale)
+    }
+  }
+
   private func doTraining(
     _ input: Tensor<Scalar>, offset: Tensor<Scalar>, scale: Tensor<Scalar>, axis: Int
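The pattern in this change is reusable: when the differentiation transform crashes on code written inline in a differentiable method (the SR-13263 class of compiler bugs), hoisting that code into a private, @inline(never), @differentiable(reverse) static helper can sidestep the crash. Below is a minimal sketch of the same pattern using only the standard library's _Differentiation module; the Scaler type and adjustedScale helper are hypothetical names, not part of this commit, and the sketch assumes a toolchain on which differentiable Swift (and this commit) builds.

import _Differentiation

// Hypothetical type standing in for a layer such as BatchNorm.
struct Scaler: Differentiable {
  var scale: Float

  @differentiable(reverse)
  func callAsFunction(_ x: Float) -> Float {
    // Route the intermediate computation through the private static helper
    // below rather than writing it inline, mirroring the commit's workaround.
    return Self.adjustedScale(scale) * x
  }

  // Kept `private` on purpose: per the comment in the commit, an equivalent
  // `public` or `internal` helper still triggers the compiler crash.
  @inline(never)
  @differentiable(reverse)
  private static func adjustedScale(_ scale: Float) -> Float {
    scale * 2
  }
}

// Usage: differentiating through the struct yields TangentVector(scale: 10.0),
// since d/d(scale) of (scale * 2) * 5 is 10.
let model = Scaler(scale: 3)
let grad = gradient(at: model) { m in m(5) }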
