This is a PyTorch implementation of the paper Primer: Searching for Efficient Transformers for Language Modeling.

The authors do an evolutionary search for transformer architectures. They name the architecture found using the search Primer (PRIMitives searched transformER). **Primer EZ** is the architecture with the two most robust modifications in Primer compared to the original transformer. Primer EZ trains a lot faster than the vanilla transformer.

The most effective modification found by the search is using a squared ReLU instead of ReLU in the position-wise feedforward module.

$y = \max(x, 0)^2$
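As a quick numerical check (separate from the implementation below), the activation simply squares the ReLU output:

```
import torch

x = torch.tensor([-2.0, 0.5, 3.0])
y = torch.relu(x) ** 2  # max(x, 0)^2
print(y)  # tensor([0.0000, 0.2500, 9.0000])
```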

The next most effective modification is a depth-wise $3 \times 1$ convolution after the multi-head projections for queries, keys, and values. The convolution is along the sequence dimension and per channel (depth-wise). To be clear, if the number of channels in each head is $d_k$, the convolution will have a separate $1 \times 3$ kernel for each of the $d_k$ channels.
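To see what depth-wise means in PyTorch terms, here is a minimal sketch (with made-up sizes) showing that setting `groups` equal to the number of channels in `nn.Conv1d` gives one $1 \times 3$ kernel per channel:

```
import torch
from torch import nn

d_k, seq_len = 4, 10
# groups=d_k gives a separate kernel for each of the d_k channels (depth-wise)
conv = nn.Conv1d(in_channels=d_k, out_channels=d_k, kernel_size=3, groups=d_k)
print(conv.weight.shape)  # torch.Size([4, 1, 3]) -> one kernel of size 3 per channel

x = torch.randn(1, d_k, seq_len)  # [batch, channels, sequence]
print(conv(x).shape)              # torch.Size([1, 4, 8]) -> no padding in this sketch
```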

Here is the experiment code for Primer EZ.

```
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import MultiHeadAttention
```

$y = \max(x, 0)^2$

Squared ReLU is used as the activation function in the position-wise feedforward module.

`class SquaredReLU(Module):`

```
def __init__(self):
    super().__init__()
    self.relu = nn.ReLU()
```

`def forward(self, x: torch.Tensor):`

Apply ReLU

`x = self.relu(x)`

Square it

`return x * x`
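For context, here is a sketch of how `SquaredReLU` could replace ReLU inside a position-wise feedforward module; `FeedForwardSketch`, `d_model`, and `d_ff` are illustrative names, not this repository's `FeedForward` API:

```
class FeedForwardSketch(nn.Module):
    # Illustrative position-wise FFN using SquaredReLU; not the repository's FeedForward module
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.activation = SquaredReLU()
        self.layer2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor):
        return self.layer2(self.activation(self.layer1(x)))
```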

`class SpatialDepthWiseConvolution(Module):`

`d_k` is the number of channels in each head.

`def __init__(self, d_k: int, kernel_size: int = 3):`

```
super().__init__()
self.kernel_size = kernel_size
```

We use PyTorch's `Conv1d` module. We set the number of groups to be equal to the number of channels so that it does a separate convolution (with a different kernel) for each channel. We add padding to both sides and later crop the rightmost `kernel_size - 1` results, so that each output position only depends on the current and earlier positions.

```
self.conv = nn.Conv1d(in_channels=d_k, out_channels=d_k,
                      kernel_size=(kernel_size,), padding=(kernel_size - 1,), groups=d_k)
```

`x` has shape `[seq_len, batch_size, heads, d_k]`.

`def forward(self, x: torch.Tensor):`

Get the shape

`seq_len, batch_size, heads, d_k = x.shape`

Permute to `[batch_size, heads, d_k, seq_len]`

`x = x.permute(1, 2, 3, 0)`

Change the shape to `[batch_size * heads, d_k, seq_len]`

`x = x.view(batch_size * heads, d_k, seq_len)`

1D convolution accepts input of the form `[N, channels, sequence]`

`x = self.conv(x)`

Crop the rightmost `kernel_size - 1` results since we padded both sides.

`x = x[:, :, :-(self.kernel_size - 1)]`

Reshape to `[batch_size, heads, d_k, seq_len]`

`x = x.view(batch_size, heads, d_k, seq_len)`

Permute to `[seq_len, batch_size, heads, d_k]`

`x = x.permute(3, 0, 1, 2)`

`return x`
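As a quick shape check (with illustrative sizes), the convolution preserves the input shape:

```
conv = SpatialDepthWiseConvolution(d_k=32)
x = torch.randn(10, 2, 8, 32)  # [seq_len, batch_size, heads, d_k]
print(conv(x).shape)           # torch.Size([10, 2, 8, 32])
```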

We extend our original implementation of Multi-Head Attention and add the spatial depth-wise convolution to query, key and value projections.

`class MultiDConvHeadAttention(MultiHeadAttention):`

```
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
    super().__init__(heads, d_model, dropout_prob)
```

Multi-Head Attention will create the query, key, and value projection modules `self.query`, `self.key`, and `self.value`.

We append a spatial depth-wise convolution layer to each of them and replace `self.query`, `self.key`, and `self.value`.

📝 *We feel this cleaner implementation is easier to understand since it clearly shows the difference between this and vanilla transformer multi-head attention*.

```
self.query = nn.Sequential(self.query, SpatialDepthWiseConvolution(self.d_k))
self.key = nn.Sequential(self.key, SpatialDepthWiseConvolution(self.d_k))
self.value = nn.Sequential(self.value, SpatialDepthWiseConvolution(self.d_k))
```
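As a usage sketch (assuming the base `MultiHeadAttention` takes `query`, `key`, and `value` tensors of shape `[seq_len, batch_size, d_model]`; the sizes below are illustrative):

```
mha = MultiDConvHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 2, 512)         # [seq_len, batch_size, d_model]
out = mha(query=x, key=x, value=x)  # self-attention
print(out.shape)                    # torch.Size([10, 2, 512])
```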