Fix DeepFM
zhenghaoz committed Nov 6, 2024
1 parent e0c3290 commit 42cd63c
Showing 2 changed files with 14 additions and 16 deletions.
16 changes: 8 additions & 8 deletions common/nn/tensor.go
@@ -215,7 +215,7 @@ func (t *Tensor) Backward() {
         inputs, output := op.inputsAndOutput()
         grads := op.backward(output.grad)
         // Clear gradient of non-leaf tensor
-        output.grad = nil
+        //output.grad = nil
         for i := range grads {
             if !slices.Equal(inputs[i].shape, grads[i].shape) {
                 panic(fmt.Sprintf("%s: shape %v does not match shape %v", op.String(), inputs[i].shape, grads[i].shape))
@@ -229,7 +229,7 @@ func (t *Tensor) Backward() {
                 ops = append(ops, inputs[i].op)
             } else if !inputs[i].requireGrad {
                 // Clear gradient if the leaf tensor does not require gradient
-                inputs[i].grad = nil
+                //inputs[i].grad = nil
             }
         }
     }
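
Note: the two hunks above disable the eager clearing of gradients in the backward pass, so after Backward() returns, non-leaf tensors (and leaves that do not require gradients) keep their grad buffers instead of having them set to nil. The commit message does not say why clearing was disabled; the toy program below (a self-contained sketch, not this package's API) only illustrates the mechanical difference between clearing and retaining intermediate gradients.

package main

import "fmt"

// Toy autograd node, not this package's API: the backward pass below either
// clears non-leaf gradients eagerly (the old behavior) or retains them
// (the behavior after this commit).
type node struct {
	inputs []*node
	grad   float32
	leaf   bool
}

func backward(out *node, retain bool) {
	stack := []*node{out}
	for len(stack) > 0 {
		n := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		for _, in := range n.inputs {
			in.grad += n.grad // backward accumulates into the inputs' gradients
			stack = append(stack, in)
		}
		if !n.leaf && !retain {
			n.grad = 0 // old behavior: drop the intermediate gradient eagerly
		}
	}
}

func main() {
	x := &node{leaf: true}
	h := &node{inputs: []*node{x}} // intermediate (non-leaf) tensor
	y := &node{inputs: []*node{h}, grad: 1}
	backward(y, true)
	fmt.Println(h.grad, x.grad) // 1 1: the intermediate gradient is kept
}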
@@ -366,7 +366,7 @@ func (t *Tensor) matMul(other *Tensor, transpose1, transpose2 bool) *Tensor {
             panic("matMul requires 2-D tensors")
         }
         if t.shape[1] != other.shape[0] {
-            panic("matMul requires the shapes of tensors are compatible")
+            panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
         }
         m, n, p := t.shape[0], t.shape[1], other.shape[1]
         result := make([]float32, m*p)
@@ -385,7 +385,7 @@ func (t *Tensor) matMul(other *Tensor, transpose1, transpose2 bool) *Tensor {
             panic("matMul requires 2-D tensors")
         }
         if t.shape[0] != other.shape[0] {
-            panic("matMul requires the shapes of tensors are compatible")
+            panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
         }
         m, n, p := t.shape[1], t.shape[0], other.shape[1]
         result := make([]float32, m*p)
@@ -404,7 +404,7 @@ func (t *Tensor) matMul(other *Tensor, transpose1, transpose2 bool) *Tensor {
             panic("matMul requires 2-D tensors")
         }
         if t.shape[1] != other.shape[1] {
-            panic("matMul requires the shapes of tensors are compatible")
+            panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
         }
         m, n, p := t.shape[0], t.shape[1], other.shape[0]
         result := make([]float32, m*p)
@@ -423,7 +423,7 @@ func (t *Tensor) matMul(other *Tensor, transpose1, transpose2 bool) *Tensor {
             panic("matMul requires 2-D tensors")
         }
         if t.shape[0] != other.shape[1] {
-            panic("matMul requires the shapes of tensors are compatible")
+            panic(fmt.Sprintf("matMul requires the shapes of tensors are compatible, but got %v and %v", t.shape, other.shape))
         }
         m, n, p := t.shape[1], t.shape[0], other.shape[0]
         result := make([]float32, m*p)
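
The four hunks above make each shape-mismatch panic print the offending shapes. Read together, they also document which dimensions must agree for each transpose combination; the helper below restates those checks as a standalone sketch (the function name is illustrative, not part of the package):

package main

import "fmt"

// matMulCompatible restates the four shape checks from the hunks above:
// for A·B with optional transposes, which dimensions must agree.
func matMulCompatible(a, b []int, transpose1, transpose2 bool) bool {
	switch {
	case !transpose1 && !transpose2:
		return a[1] == b[0] // A·B: A is m×n, B is n×p
	case transpose1 && !transpose2:
		return a[0] == b[0] // Aᵀ·B: A is stored n×m, B is n×p
	case !transpose1 && transpose2:
		return a[1] == b[1] // A·Bᵀ: A is m×n, B is stored p×n
	default:
		return a[0] == b[1] // Aᵀ·Bᵀ: A is stored n×m, B is stored p×n
	}
}

func main() {
	fmt.Println(matMulCompatible([]int{2, 3}, []int{3, 5}, false, false)) // true
	fmt.Println(matMulCompatible([]int{2, 3}, []int{5, 3}, false, true))  // true: Bᵀ is 3×5
	fmt.Println(matMulCompatible([]int{2, 3}, []int{4, 5}, false, false)) // false: would now panic with both shapes printed
}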
@@ -533,11 +533,11 @@ func (t *Tensor) batchMatMul(other *Tensor, transpose1, transpose2 bool) *Tensor
 func (t *Tensor) maximum(other *Tensor) {
     if other.IsScalar() {
         for i := range t.data {
-            t.data[i] = math32.Max(t.data[i], other.data[0])
+            t.data[i] = max(t.data[i], other.data[0])
         }
     } else {
         for i := range t.data {
-            t.data[i] = math32.Max(t.data[i], other.data[i])
+            t.data[i] = max(t.data[i], other.data[i])
         }
     }
 }
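
Replacing math32.Max with the predeclared max drops a library call: Go 1.21 added built-in min and max over ordered types, so they apply to float32 directly (per the spec, max propagates NaN operands). A quick standalone check:

package main

import "fmt"

func main() {
	// Go 1.21+: max is predeclared and generic over ordered types,
	// so it works on float32 without importing math32.
	a, b := float32(-1.5), float32(0.25)
	fmt.Println(max(a, b)) // 0.25
}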
14 changes: 6 additions & 8 deletions model/click/deepfm_v2.go
@@ -199,7 +199,7 @@ func (fm *DeepFMV2) Fit(ctx context.Context, trainSet *Dataset, testSet *Dataset
     }
     indices, values, target := fm.convertToTensors(x, y)
 
-    //optimizer := nn.NewAdam(fm.Parameters(), fm.lr)
+    optimizer := nn.NewAdam(fm.Parameters(), fm.lr)
     for epoch := 1; epoch <= fm.nEpochs; epoch++ {
         fitStart := time.Now()
         cost := float32(0)
@@ -208,13 +208,11 @@ func (fm *DeepFMV2) Fit(ctx context.Context, trainSet *Dataset, testSet *Dataset
             batchValues := values.Slice(i, i+fm.batchSize)
             batchTarget := target.Slice(i, i+fm.batchSize)
             batchOutput := fm.Forward(batchIndices, batchValues)
-            batchOutput.Backward()
-            _ = batchTarget
-            //batchLoss := nn.BCEWithLogits(batchTarget, batchOutput)
-            //cost += batchLoss.Data()[0]
-            //optimizer.ZeroGrad()
-            //batchLoss.Backward()
-            //optimizer.Step()
+            batchLoss := nn.BCEWithLogits(batchTarget, batchOutput)
+            cost += batchLoss.Data()[0]
+            optimizer.ZeroGrad()
+            batchLoss.Backward()
+            optimizer.Step()
         }
 
         fitTime := time.Since(fitStart)
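
This hunk is the substance of the fix: previously the loss and optimizer lines were commented out, batchTarget was discarded, and Backward() ran on the raw logits, so DeepFMV2 never actually trained. The restored loop follows the usual order: forward, loss, ZeroGrad, Backward, Step. Because backward passes typically accumulate into existing gradients, ZeroGrad() must run before each Backward(); a toy, self-contained illustration (not this project's API):

package main

import "fmt"

// Toy parameter, not the project's API: shows why ZeroGrad must run before
// each Backward when backward passes accumulate into grad.
type param struct{ value, grad float32 }

func (p *param) zeroGrad()          { p.grad = 0 }
func (p *param) backward(g float32) { p.grad += g } // accumulate, as autograd does
func (p *param) step(lr float32)    { p.value -= lr * p.grad }

func main() {
	p := &param{value: 1}
	for batch := 0; batch < 3; batch++ {
		p.zeroGrad()    // skip this and stale gradients from earlier batches leak in
		p.backward(0.5) // pretend the loss gradient for this batch is 0.5
		p.step(0.1)
	}
	fmt.Printf("%.2f\n", p.value) // 0.85: three clean updates
}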