Transformer fundamentals
 
working_gpt.Head Class Reference

One head of self-attention. More...

Inheritance diagram for working_gpt.Head: torch.nn.Module → working_gpt.Head

Public Member Functions

 __init__ (self, head_size)
 
 forward (self, x)
 

Public Attributes

 key = nn.Linear(n_embd, head_size, bias=False)
 
 query = nn.Linear(n_embd, head_size, bias=False)
 
 value = nn.Linear(n_embd, head_size, bias=False)
 
 dropout = nn.Dropout(dropout)
 

Detailed Description

One head of self-attention. The head projects its input with three bias-free linear layers into key, query, and value spaces, computes scaled dot-product attention scores, masks out future positions with a lower-triangular (causal) mask, normalizes the scores with a softmax, applies dropout, and returns the attention-weighted sum of the values. The names n_embd, block_size, and dropout that appear in the class body are module-level hyperparameters of working_gpt.py.

Definition at line 80 of file working_gpt.py.
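The scores computed in forward() are scaled by k.shape[-1] ** -0.5. A quick sketch of why, using assumed toy sizes rather than anything from working_gpt.py: with unit-variance queries and keys, raw dot products have variance about head_size, and the scaling brings it back near 1 so the softmax does not saturate.

```python
import torch

# Sketch: why attention scores are scaled by head_size ** -0.5.
# The sizes below are assumptions for illustration only.
torch.manual_seed(0)
head_size = 64
q = torch.randn(1000, head_size)  # unit-variance queries
k = torch.randn(1000, head_size)  # unit-variance keys

raw = q @ k.T                     # variance grows like head_size
scaled = raw * head_size ** -0.5  # variance brought back near 1

print(raw.var().item())     # roughly head_size (~64)
print(scaled.var().item())  # roughly 1
```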

Constructor & Destructor Documentation

◆ __init__()

working_gpt.Head.__init__(self, head_size)

Definition at line 83 of file working_gpt.py.

83 def __init__(self, head_size):
84 super().__init__()
85 self.key = nn.Linear(n_embd, head_size, bias=False)
86 self.query = nn.Linear(n_embd, head_size, bias=False)
87 self.value = nn.Linear(n_embd, head_size, bias=False)
88 self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
89
90 self.dropout = nn.Dropout(dropout)
91
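The tril buffer registered here is what forward() later uses for causal masking. A small sketch of its effect, with an assumed block_size of 4:

```python
import torch

# Lower-triangular ones, as produced by register_buffer("tril", ...)
# with an assumed block_size of 4:
tril = torch.tril(torch.ones(4, 4))
# tril == [[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]

# masked_fill puts -inf where tril == 0, so softmax gives those
# positions exactly zero weight: each token only attends backwards.
scores = torch.zeros(4, 4)  # stand-in attention scores
wei = scores.masked_fill(tril == 0, float("-inf"))
wei = torch.softmax(wei, dim=-1)
print(wei[0])  # position 0 attends only to itself: [1., 0., 0., 0.]
```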


Member Function Documentation

◆ forward()

working_gpt.Head.forward(self, x)

Definition at line 92 of file working_gpt.py.

92 def forward(self, x):
93 # input of size (batch, time-step, channels)
94 # output of size (batch, time-step, head size)
95 B, T, C = x.shape
96 k = self.key(x) # (B,T,hs)
97 q = self.query(x) # (B,T,hs)
98 # compute attention scores ("affinities")
99 wei = (
100 q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
101 ) # (B, T, hs) @ (B, hs, T) -> (B, T, T)
102 wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # (B, T, T)
103 wei = F.softmax(wei, dim=-1) # (B, T, T)
104 wei = self.dropout(wei)
105 # perform the weighted aggregation of the values
106 v = self.value(x) # (B,T,hs)
107 out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
108 return out
109
110
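A minimal usage sketch. Head depends on module-level hyperparameters, so the snippet re-declares the class with assumed toy values (n_embd=32, block_size=8, dropout=0.0), not the values working_gpt.py actually uses:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

# Assumed toy hyperparameters; working_gpt.py defines its own values.
n_embd, block_size, dropout = 32, 8, 0.0

class Head(nn.Module):
    """One head of self-attention, as documented above."""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k, q = self.key(x), self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ self.value(x)  # (B, T, head_size)

head = Head(head_size=16)
x = torch.randn(4, block_size, n_embd)  # (batch, time, channels)
out = head(x)
print(out.shape)  # torch.Size([4, 8, 16])
```

With dropout at 0.0 the head is deterministic, which makes it easy to check that the causal mask works: zeroing the last time step of the input leaves the outputs at all earlier time steps unchanged.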

References dropout, key, query, and value.

Member Data Documentation

◆ dropout

working_gpt.Head.dropout = nn.Dropout(dropout)

Definition at line 90 of file working_gpt.py.

Referenced by forward(), and working_gpt.MultiHeadAttention.forward().

◆ key

working_gpt.Head.key = nn.Linear(n_embd, head_size, bias=False)

Definition at line 85 of file working_gpt.py.

Referenced by forward().

◆ query

working_gpt.Head.query = nn.Linear(n_embd, head_size, bias=False)

Definition at line 86 of file working_gpt.py.

Referenced by forward().

◆ value

working_gpt.Head.value = nn.Linear(n_embd, head_size, bias=False)

Definition at line 87 of file working_gpt.py.

Referenced by forward().


The documentation for this class was generated from the following file:

working_gpt.py