import torch
import torch.nn as nn
from torch.nn import functional as F

# pick the training device: Apple MPS if available, otherwise CUDA, otherwise CPU
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

torch.manual_seed(1337)

with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]  # string -> list of integer token ids
decode = lambda l: "".join([itos[i] for i in l])  # list of integer token ids -> string

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # first 90% of the data is for training, the rest for validation
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    # sample a random batch of input chunks x and next-token targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
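
# Each sampled chunk provides block_size training examples at once: y is x shifted left by one
# position, so the token at position t in x is trained to predict the token at position t + 1.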

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
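
# Averaging over eval_iters freshly sampled batches gives a far less noisy loss estimate than
# the loss of any single batch.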

class Head(nn.Module):
    """one head of self-attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular mask used to block attention to future positions
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        # scaled dot-product attention scores, masked so each position only attends to the past
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        return wei @ v  # weighted aggregation of the values
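
# In effect the head computes softmax(q @ k^T / sqrt(head_size)) @ v. The 1/sqrt(head_size)
# scaling keeps the dot products at roughly unit variance so the softmax does not saturate,
# and the triangular mask makes attention causal: position t only looks at positions <= t.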

class MultiHeadAttention(nn.Module):
    """multiple heads of self-attention in parallel"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)

    def forward(self, x):
        # run the heads side by side, concatenate along the channel dimension, then project back
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)
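
# Several small heads run in parallel, each attending over its own head_size-dimensional
# subspace; concatenating their outputs and applying proj mixes the results back into the
# n_embd-wide residual stream.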

class FeedForward(nn.Module):
    """a simple linear layer followed by a non-linearity"""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Transformer block: communication followed by computation"""

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # residual connections around the attention ("communication") and feed-forward ("computation") sub-layers
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
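
# This is the pre-norm Transformer layout: LayerNorm is applied before each sub-layer and the
# sub-layer's output is added back onto x, leaving a clean residual path for gradients.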

class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # token identities and positions each get an embedding table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
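
    # A standard deviation of 0.02 mirrors the GPT-2 initialization convention for Linear and
    # Embedding weights.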

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T, n_embd)
        x = tok_emb + pos_emb  # positions broadcast across the batch
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
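
    # F.cross_entropy expects logits of shape (N, C) and integer class targets of shape (N,),
    # which is why the (B, T, C) logits and (B, T) targets are flattened before the loss.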

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) tensor of token ids in the current context
        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            # keep only the logits for the last time step
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
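
# Generation is plain autoregressive sampling: each step feeds at most the last block_size
# tokens (the position embedding table has no entries beyond that), takes the logits at the
# final position, and draws the next token from the softmax distribution via torch.multinomial.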

model = GPTLanguageModel()
m = model.to(device)
# report the parameter count in millions
print(sum(p.numel() for p in m.parameters()) / 1e6, "M parameters")

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
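
# AdamW applies weight decay decoupled from the adaptive gradient update; here it runs with
# PyTorch defaults apart from the learning rate.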

for iter in range(max_iters):

    # periodically evaluate the loss on the train and val splits
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(
            f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )

    # sample a batch of data and take a single optimization step
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
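
# zero_grad(set_to_none=True) clears gradients by dropping the tensors rather than filling them
# with zeros, which saves a little memory and work on every step.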

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
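
# Token id 0 is simply the first character in the sorted vocabulary; a single zero token acts as
# a minimal prompt for sampling 500 new characters from the trained model.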