Transformer fundamentals
 
working_gpt.MultiHeadAttention Class Reference

Multiple heads of self-attention in parallel.

Inheritance diagram for working_gpt.MultiHeadAttention:

Public Member Functions

 __init__ (self, num_heads, head_size)
 
 forward (self, x)
 

Public Attributes

 heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
 
 proj = nn.Linear(head_size * num_heads, n_embd)
 
 dropout = nn.Dropout(dropout)
 

Detailed Description

Multiple heads of self-attention in parallel.

Definition at line 111 of file working_gpt.py.
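The module runs num_heads attention heads on the same input in parallel, concatenates their outputs along the channel dimension, and projects the result back to n_embd. A minimal shape walkthrough, using illustrative sizes rather than values taken from working_gpt.py:

import torch
import torch.nn as nn

# Illustrative sizes (assumed for this sketch, not read from working_gpt.py).
B, T, n_embd = 2, 8, 32          # batch, time, embedding dimension
num_heads, head_size = 4, 8      # 4 heads * 8 channels = 32 = n_embd

# Stand-ins for the per-head outputs h(x) computed in forward().
head_outs = [torch.randn(B, T, head_size) for _ in range(num_heads)]

out = torch.cat(head_outs, dim=-1)              # (B, T, num_heads * head_size) -> (2, 8, 32)
proj = nn.Linear(head_size * num_heads, n_embd)
out = proj(out)                                 # projected back to (B, T, n_embd)
print(out.shape)                                # torch.Size([2, 8, 32])

The proj layer mixes information from the individual heads back into a single n_embd-dimensional representation.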

Constructor & Destructor Documentation

◆ __init__()

working_gpt.MultiHeadAttention.__init__(self, num_heads, head_size)

Definition at line 114 of file working_gpt.py.

114 def __init__(self, num_heads, head_size):
115 super().__init__()
116 self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
117 self.proj = nn.Linear(head_size * num_heads, n_embd)
118 self.dropout = nn.Dropout(dropout)
119

References __init__().

Referenced by __init__().
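The constructor builds one Head per attention head, a shared output projection, and a dropout layer. Head and the module-level n_embd and dropout hyperparameters are defined elsewhere in working_gpt.py and are not documented on this page; a minimal single-head sketch, assuming the usual causal self-attention pattern and illustrative hyperparameter values, might look like:

import torch
import torch.nn as nn
from torch.nn import functional as F

# Assumed module-level hyperparameters (working_gpt.py defines its own values).
n_embd = 32       # embedding dimension
block_size = 8    # maximum context length for the causal mask
dropout = 0.1     # dropout probability

class Head(nn.Module):
    """One head of causal self-attention (sketch of the assumed interface)."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)                                      # (B, T, head_size)
        q = self.query(x)                                    # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T) scaled scores
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal mask
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)                                    # (B, T, head_size)
        return wei @ v                                       # (B, T, head_size)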

Member Function Documentation

◆ forward()

working_gpt.MultiHeadAttention.forward(self, x)

Definition at line 120 of file working_gpt.py.

120 def forward(self, x):
121 out = torch.cat([h(x) for h in self.heads], dim=-1)
122 out = self.dropout(self.proj(out))
123 return out
124
125

References working_gpt.Head.dropout, dropout, heads, and proj.
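forward() evaluates every head on the same input, concatenates the per-head outputs along the last dimension, and applies the output projection followed by dropout. Assuming the Head sketch above (with its illustrative n_embd = 32, block_size = 8, and dropout = 0.1) and the MultiHeadAttention definition listed on this page are both in scope, a usage example looks like:

import torch

# Hyperparameters match the illustrative values used in the Head sketch above.
num_heads, head_size = 4, 8          # 4 * 8 = 32 = n_embd
mha = MultiHeadAttention(num_heads, head_size)

x = torch.randn(2, 8, 32)            # (batch, time, n_embd)
y = mha(x)                           # each head: (2, 8, 8); concat: (2, 8, 32); proj: (2, 8, 32)
print(y.shape)                       # torch.Size([2, 8, 32])

Because head_size * num_heads equals n_embd here, the output has the same shape as the input.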

Member Data Documentation

◆ dropout

working_gpt.MultiHeadAttention.dropout = nn.Dropout(dropout)

Definition at line 118 of file working_gpt.py.

Referenced by forward().

◆ heads

working_gpt.MultiHeadAttention.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

Definition at line 116 of file working_gpt.py.

Referenced by forward().

◆ proj

working_gpt.MultiHeadAttention.proj = nn.Linear(head_size * num_heads, n_embd)

Definition at line 117 of file working_gpt.py.

Referenced by forward().


The documentation for this class was generated from the following file:
working_gpt.py