Skip to content

深入循环神经网络

Published: at 03:22 PM

从根本上说,循环神经网络(RNN)通过随着时间的推移递归应用相同的权重集在序列上运行。让我们深入了解它们的架构,理解它们为什么在序列建模任务中异常有效。

RNN背后的数学

基本RNN架构

基本的RNN计算可以表示为:

class RNN:
    def step(self, x, h):
        # Update the hidden state
        h_new = np.tanh(np.dot(W_hh, h) + np.dot(W_xh, x) + b_h)
        # Compute the output vector
        y = np.dot(W_hy, h_new) + b_y
        return h_new, y

以下是详细信息:

  • W_hh: 隐藏到隐藏连接的权重
  • W_xh: 输入到隐藏连接的权重
  • W_hy: 隐藏到输出连接的权重
  • b_h, b_y: 偏置项
  • h: 隐藏状态向量
  • x: 输入向量
  • y: 输出向量

前向传播和随时间反向传播(BPTT)

真正的魔法发生在训练期间。让我们实现一个基本版本:

def forward_backward_pass(inputs, targets, h_prev):
    # Forward pass
    h_states = []
    outputs = []
    h = h_prev
    loss = 0
    
    # Forward pass
    for t in range(len(inputs)):
        h, y = step(inputs[t], h)
        h_states.append(h)
        outputs.append(y)
        loss += -np.log(y[targets[t]])  # Cross-entropy loss
    
    # Backward pass
    dW_hh, dW_xh, dW_hy = np.zeros_like(W_hh), np.zeros_like(W_xh), np.zeros_like(W_hy)
    db_h, db_y = np.zeros_like(b_h), np.zeros_like(b_y)
    dh_next = np.zeros_like(h_states[0])
    
    for t in reversed(range(len(inputs))):
        # Gradient computation goes here
        # This is where BPTT happens
        pass
    
    return loss, dW_hh, dW_xh, dW_hy, db_h, db_y

字符级语言模型:具体示例

让我们实现一个字符级语言模型来演示RNN的能力:

class CharRNN:
    def __init__(self, vocab_size, hidden_size):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        
        # Initialize weights
        self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.W_xh = np.random.randn(hidden_size, vocab_size) * 0.01
        self.W_hy = np.random.randn(vocab_size, hidden_size) * 0.01
        self.b_h = np.zeros((hidden_size, 1))
        self.b_y = np.zeros((vocab_size, 1))
    
    def sample(self, h, seed_ix, n):
        x = np.zeros((self.vocab_size, 1))
        x[seed_ix] = 1
        generated = []
        
        for t in range(n):
            h, y = self.step(x, h)
            p = np.exp(y) / np.sum(np.exp(y))
            ix = np.random.choice(range(self.vocab_size), p=p.ravel())
            x = np.zeros((self.vocab_size, 1))
            x[ix] = 1
            generated.append(ix)
            
        return generated

实际中的不合理的有效性

1. 文本生成

当在大型文本语料库上训练时,我们的字符级模型学习:

  • 正确的拼写和单词形成
  • 基本语法和标点
  • 上下文中适当的词汇
  • 类型特定的写作风格

这特别 Remarkable 的是:

  • 模型每次只看到一个字符
  • 它没有内置的单词或语法理解
  • 它从序列中的统计模式学习一切

2. 源代码生成

RNNs甚至可以学习编程语言的语法和模式。例如:

def generate_code(model, seed="def"):
    return model.sample(seed, length=1000)

该模型学习:

  • 正确的缩进
  • 匹配的括号和括号
  • 函数和变量命名约定
  • 基本编程模式

3. 记忆的数学

隐藏状态h充当网络的记忆。在每一个时间步t:

h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h)

这个递归公式允许网络:

  • 维持长期依赖
  • 忘记不相关信息
  • 建立分层表示

高级主题:处理消失梯度

LSTM单元

长短期记忆(LSTM)架构解决了消失梯度问题:

def lstm_step(x, h_prev, c_prev):
    # Gates
    f = sigmoid(W_f.dot(x) + U_f.dot(h_prev) + b_f)
    i = sigmoid(W_i.dot(x) + U_i.dot(h_prev) + b_i)
    o = sigmoid(W_o.dot(x) + U_o.dot(h_prev) + b_o)
    # New memory content
    g = tanh(W_g.dot(x) + U_g.dot(h_prev) + b_g)
    # Update cell state
    c = f * c_prev + i * g
    # Update hidden state
    h = o * tanh(c)
    return h, c

梯度裁剪

为防止爆炸梯度:

def clip_gradients(gradients, max_norm=5):
    norm = np.sqrt(sum(np.sum(grad ** 2) for grad in gradients))
    if norm > max_norm:
        scale = max_norm / norm
        return [grad * scale for grad in gradients]
    return gradients

训练RNNs的实用技巧

  1. 初始化:使用小的随机权重以防止饱和:
W = np.random.randn(n_in, n_out) * np.sqrt(2.0/n_in)
  1. Mini-batch Processing:实现批处理以提高效率:
def process_batch(batch_inputs, batch_size):
    h = np.zeros((batch_size, hidden_size))
    for t in range(seq_length):
        h = step(batch_inputs[t], h)
  1. 学习率调度:实现自适应学习率:
learning_rate = base_lr * decay_rate ** (epoch / decay_steps)

超越简单序列

RNNs可以扩展以处理更复杂的模式:

  1. 双向RNNs:双向处理序列
  2. 深度RNNs:堆叠多个RNN层
  3. 注意力机制:允许网络专注于输入序列的相关部分

结论

RNN的有效性来自于它们通过简单、递归操作学习复杂模式的能力。虽然像Transformers这样的新架构已经出现,但来自RNN的根本性洞察 continues to influence deep learning design。

理解它们的数学基础和实现细节帮助我们欣赏它们为什么工作得如此好以及如何在实践中有效地使用它们。


Previous Post
生成式 AI 入门指南
Next Post
生成式 AI 术语表