From e01c44df1e65a9cdf687e9c42362c439a5b57bea Mon Sep 17 00:00:00 2001
From: Jayoung Ryu
Date: Tue, 4 Oct 2022 23:04:29 -0400
Subject: [PATCH] add FixOperator to fix RHS node

---
 torchbiggraph/operators.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torchbiggraph/operators.py b/torchbiggraph/operators.py
index 3749f5d..1ac4ff0 100755
--- a/torchbiggraph/operators.py
+++ b/torchbiggraph/operators.py
@@ -56,6 +56,15 @@ def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
     def get_operator_params_for_reg(self) -> Optional[FloatTensorType]:
         return None
 
+@OPERATORS.register_as("fix")
+class FixOperator(AbstractOperator):
+    # Detach the node tensor so that the loss is not propagated to the node embedding.
+    def forward(self, embeddings: FloatTensorType) -> FloatTensorType:
+        match_shape(embeddings, ..., self.dim)
+        return embeddings.clone().detach()
+
+    def get_operator_params_for_reg(self) -> Optional[FloatTensorType]:
+        return None
 
 @OPERATORS.register_as("diagonal")
 class DiagonalOperator(AbstractOperator):
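
The operator freezes the side it is applied to by cutting the autograd graph: clone().detach() returns a tensor that no longer requires grad, so any loss built on it cannot update the underlying embedding. A minimal sketch of that behaviour (illustration only, not part of the patch; it assumes nothing beyond standard PyTorch autograd semantics):

    import torch

    emb = torch.randn(4, 8, requires_grad=True)

    # identity-style operator: the loss reaches the embedding.
    (emb ** 2).sum().backward()
    assert emb.grad is not None
    emb.grad = None

    # "fix"-style operator: clone().detach() leaves the autograd graph,
    # so a loss built on this output cannot update `emb`.
    fixed = emb.clone().detach()
    assert not fixed.requires_grad

Since the class is registered via OPERATORS.register_as("fix"), it should be selectable by name like the other operators, e.g. operator="fix" on the relation whose RHS embeddings are to stay fixed (assuming the usual PBG relation config schema; the config side is not touched by this patch).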