
Commit cf4035d
create a cross attention only attention layer (CrossAttender)
lucidrains committed Dec 14, 2020
1 parent 7b80b96
Showing 4 changed files with 9 additions and 11 deletions.
12 changes: 3 additions & 9 deletions README.md
````diff
@@ -519,22 +519,16 @@ Cross Attention
 
 ```python
 import torch
-from x_transformers import Encoder
+from x_transformers import Encoder, CrossAttender
 
 enc = Encoder(dim = 512, depth = 6)
 
-cross_attn = Encoder(
-    dim = 512,
-    depth = 6,
-    cross_attend = True,
-    only_cross = True
-)
+model = CrossAttender(dim = 512, depth = 6)
 
 nodes = torch.randn(1, 1, 512)
 neighbors = torch.randn(1, 5, 512)
 
 encoded_neighbors = enc(neighbors)
-cross_attn(nodes, context = encoded_neighbors) # (1, 1, 512)
+model(nodes, context = encoded_neighbors) # (1, 1, 512)
 ```
 
 ## Citations
````
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.3',
+  version = '0.3.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
```
2 changes: 1 addition & 1 deletion x_transformers/__init__.py
```diff
@@ -1,3 +1,3 @@
-from x_transformers.x_transformers import XTransformer, Encoder, Decoder, TransformerWrapper, ViTransformerWrapper
+from x_transformers.x_transformers import XTransformer, Encoder, Decoder, CrossAttender, TransformerWrapper, ViTransformerWrapper
 from x_transformers.funnel import FunnelEncoder
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
```
4 changes: 4 additions & 0 deletions x_transformers/x_transformers.py
```diff
@@ -353,6 +353,10 @@ def __init__(self, **kwargs):
         assert 'causal' not in kwargs, 'cannot set causality on decoder'
         super().__init__(causal = True, **kwargs)
 
+class CrossAttender(AttentionLayers):
+    def __init__(self, **kwargs):
+        super().__init__(cross_attend = True, only_cross = True, **kwargs)
+
 class ViTransformerWrapper(nn.Module):
     def __init__(
         self,
```
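Per the diff, the new class is simply `AttentionLayers` pinned to `cross_attend = True, only_cross = True`, so each layer attends from the input sequence to the supplied `context` with no self-attention. A minimal smoke test of the new export (a sketch assuming this version of the package is installed; the tensor shapes are illustrative):

```python
import torch
from x_transformers import CrossAttender

# Stack of cross-attention layers: queries come from the input,
# keys/values from `context`.
model = CrossAttender(dim = 512, depth = 6)

queries = torch.randn(1, 1, 512)    # sequence doing the attending
context = torch.randn(1, 5, 512)    # sequence being attended to

out = model(queries, context = context)
assert out.shape == (1, 1, 512)     # output keeps the query sequence's shape
```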
