Skip to content

Commit

Permalink
Update src/adapters/methods/reft.py
Browse files Browse the repository at this point in the history
Co-authored-by: Leon Engländer <leon.englaender@gmail.com>
  • Loading branch information
calpt and lenglaender authored Jan 18, 2025
1 parent a66f4ae commit e964973
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/adapters/methods/reft.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def __init__(
projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
if orthogonal:
# orthogonal is not implemented for half precision
projection = projection.to(dtype=torch.float32)
if dtype in [torch.float16, torch.bfloat16]:
warnings.warn(
"Orthogonal parametrization is not supported for half precision dtypes. Converting REFT projection layer to float32.",
UserWarning
)
projection = projection.to(dtype=torch.float32)
self.projection = nn.utils.parametrizations.orthogonal(projection)
else:
self.projection = projection
Expand Down

0 comments on commit e964973

Please sign in to comment.