类演示水平的量级。RLPDReinforcement Learning with Prior Data是一种基于 off-policy 的 actor-critic强化学习算法借鉴了 soft-actor-critic 等时序差分算法的成功经验但为满足上述需求做出了一些关键修改它其实就是 SAC 之前的数据Prior Data 极高的更新频率High UTD。注本系列的最终目标是“通过一系列相关项目/算法的解读来深入学习/分析/反推 LWDLearning while Deploying这篇论文的机理和可能实现”。之所以从SERL入手是因为 SERLHIL-SERLSOP没有开源都是罗剑岚博士的一系列论文可以从中管窥作者的思路脉络。本文依然是从工程/论文进行反推还请读者不吝指出问题多谢。0x01 基础 背景1.1 总体流程图SERL 的总体流程图如下其中SACSoft Actor-Critic是算法底座 是整个系统的引擎。RLPD (RL with Prior Data) 是性能加速器是 SERL 实现20 分钟学会的关键它通过暴力更新和不忘初心来榨取每一条数据的价值。High UTD迫使模型对有限的样本进行 深度研读从而在极短的物理时间内捕捉到成功的信号。1.2 面临的问题在许多场景中强化学习的优异表现依赖于与环境进行大量的在线交互这通常通过使用模拟器来实现。然而在实际问题中常常面临样本获取成本高昂的情况。此外奖励信号稀疏且高维的状态和动作空间往往使这一问题更加严重。一些先前的研究致力于通过预训练利用这些数据而其他方法则在在线训练时引入约束以应对分布转移问题。然而每种方法都有其缺点例如需要额外的训练时间和超参数或者在行为策略之外的提升有限。0x02 RLPD 基础RLPD关注的是是否可以在 在线学习时直接应用现有的离策略方法以充分利用离线数据。在每一步训练中RLPD 在先验(离线)数据和on-policy数据之间等概率采样以形成一个训练批次即“对称采样”即每个批次有50%的数据来自在线回放缓冲区另外50%来自离线数据缓冲区先验数据。我们对比如下原生纯靠自己试SERL在 SAC 的 Batch 里塞进人类演示数据相当于给 SAC 考试时递了一张带有参考答案的纸让它不用从头瞎猜那RLPD这四个字母贵在哪里秘密不在怎么算Loss Function而在喂什么Batch Composition。RLPD 50/50 抽样 强行保证每个 Batch 都有 128 条 Demo。这相当于给机器人装了一个强制记忆模块让它每一秒钟都在看正确答案。2.1 RLPD 的三大支柱High UTDUpdate-to-DataSERL 每跑一步环境就进行 20 次网络更新。这种暴力刷题让每一条物理数据都被反复压榨是 20 分钟收敛的硬件级保证。Layer NormalizationLN问题在普通的 SAC 中, 网络通常就是纯 MLP。但在 RLPD 中, 由于我们要做 High UTD (高频更新), Q 值的估算非常容易发散。解决方案因此RLPD 会在 Critic 网络 (甚至 Actor 网络) 的每一个隐藏层之后, 都加上 LayerNorm。直观理解高 UTD 会产生巨大的梯度冲击。LN 就给高速赛车装上强力悬挂确保每一层神经元的输出在 20 倍更新频率下依然稳定不至于飞出去。50/50 混合采样做法在 update 过程中, 我们不再只从 replay_buffer 抽数据, 而是将演示数据Demo与在线数据Online强行等比例混合确保智能体在探索新坑时始终能看到正确路标。抽 128 个在线数据 (自己跑出来的)。抽 128 个先验数据 (人类演示的 Demo)。把它们拼成一个 256 的 Batch 进行训练。为什么有效?演示数据告诉智能体正确的路长啥样, 在线数据告诉智能体这里的坑别踩。两者结合, 收敛速度能提升数倍。2.2 Prior Data 策略Prior Data 策略决定了演示数据与在线数据的黄金采样配比。真机强化学习最怕冷启动— 机器人像没头苍蝇一样乱撞。SERL 引入了 RLPDReinforcement Learning with Prior Data机制数据混洗在训练的每一个 Batch 中系统会强制性地混合50% 的演示数据这是由人类老师录制好的标准答案50% 的在线数据这是机器人自己折腾出来的实战记录算法价值这种混合采样确保了模型在进化的每一秒都在不断对比正确答案与自己的尝试。它解决了强化学习初期的探索困境让机器人即使在完全没拿高分的情况下也能通过模仿专家数据来迅速建立起任务的初步认知。我们可以这样理解prior data 像老师给的标准答案online replay 像学生自己的练习记录每次训练都同时看标准答案和自己的错题策略就不会偏离任务太远这也是 SERL 样本效率高的关键之一。它不是让机器人从零开始乱试而是在 demonstrations 的引导下进行强化学习微调。论文中明确写到每次更新使用 sample-based approximation其中 half of the samples drawn from prior datahalf drawn from replay buffer。2.3 算法流程图下面是根据 rlpd.py 的代码逻辑整理的 RLPD 训练流程逻辑图。这个图展示了从数据采样到网络更新的完整路径, 特别标注了 RLPD 相对于普通 SAC 的核心改进点 (如 BC Loss 和 Pessimistic Backup)。2.4 代码细节RLPD 关键组件说明Ensemble Q (10 Qs): 使用 10 个 Q 网络而非 2 个, 增加评估的多样性。rho (Pessimism): 通过 均值 - rho * 标准差 来实现对不确定区域的惩罚。BC Loss: 在 Actor 更新中加入行为克隆, 强迫智能体初始阶段不要偏离演示数据。High UTD (Scan): JAX 的lax.scan允许在一个硬件循环内执行 20 次上述流程。LayerNorm (代码内部) 在所有隐藏层后强制执行归一化, 支撑高频更新。结合 rlpd.py 代码的深度解读:关于 rho 的计算:next_qs self.network.select(target_critic)(batch[next_observations][..., -1, :], next_actions) next_q next_qs.mean(axis0) - next_qs.std(axis0) * self.config[rho]这是 RLPD 的精髓。普通的 SAC 是 min(Q1, Q2), 而这里是用标准差来量化不确定性。如果 10 个 Q 网络对某个状态动作意见不统一 (std 大), next_q 就会被压得很低。关于 BC Lossbc_loss -(dist.log_prob(jnp.clip(batch_actions, -1 1e-5, 1 - 1e-5)) * batch[valid][..., -1]).mean() * self.config[bc_alpha]这行代码在告诉 Actor: 不管 Q 值怎么说, 你输出的动作最好和 Buffer 里的真实动作 (演示数据) 接近一些。这对机器人任务极其关键, 因为它防止了机器人在训练初期因为乱甩而撞坏硬件。关于 High UTD:jax.jit def batch_update(self, batch): agent, infos jax.lax.scan(self._update, self, batch)这里使用了 JAX 的 scan 原语。这比 Python 的 for 循环快得多, 它能把 20 次更新编译成一个高效的 GPU 算子。演示数据Prior DataSERL 所基于的 RLPD 算法中作者发现最简单、最有效的办法就是一视同仁比如从 Replay Buffer自己跑的数据取 128 个。从 Demo Buffer演示数据取 128 个。凑成一个 256 的 Batch直接丢进 Loss 函数。意义不需要复杂的权重计算这种简单的对半开采样就能极大地提升效率。看未来与看现在的逻辑rlpd.py 中计算 next_q 用的是 target_critic, 而计算 actor_loss 时用的是 critic (当前网络)。Target Critic (看未来): 用于计算r γ Q_{target}。由于 Q_{target} 更新得很慢 (Soft Update), 它提供了一个稳定的地基, 防止 Q 值计算产生正反馈螺旋 (即自己把自己估高)。Critic (看现在): 用于 Actor 的更新。Actor 问: 我现在的动作好吗? 。由于 Critic 正在被最快地训练, 它能给 Actor 提供最及时的反馈。矛盾解决: 这就是评估要稳 (Target), 改进要快 (Current)的权衡。0x03 BC在 SERL 的复现中, 通常的步骤是:收集 20 个 Demo。跑 bc.py 进行预训练, 让机器人学会手往哪放。跑 rlpd.py 进行正式训练, 利用 BC 训练好的模型作为起点, 通过 50/50 采样快速进化。在 SERL 的完整流程中bc.py 的代码非常关键, 它揭示了 SERL 系统中 Behavioral Cloning (BC) 环节是如何运作的。3.1 BCAgentBCAgent 负责预训练冷启动是纯监督学习的行为克隆实现架构最为简洁。SERL先用 BC 模仿 Demo让机器人学会手往哪放再开启 RL 寻找怎么抓取。3.1.1 BCAgent 在SERL中的作用冷启动为RL算法提供初始策略演示数据利用从专家演示中学习安全基线在RL训练初期提供安全策略核心优势BCAgent 的简洁性使其成为从演示到强化学习的理想桥梁通过监督学习快速获得可用的策略然后可以在此基础上进行RL微调。BCAgent 的特性如下特性BCAgent说明网络数量仅1个Policy网络无Critic无Temperaturetanh_squashFalse不使用tanh压缩输出分布MultivariateNormalDiag标准高斯分布训练目标最小化MSE 负对数似然监督学习3.1.2 BCAgent 核心组件唯一的 Policy 网络network_kwargs[activate_final] True networks { actor: Policy( encoder_def, # 视觉编码器 MLP(**network_kwargs), # 默认 [256, 256] action_dimactions.shape[-1], tanh_squash_distributionFalse, # 关键差异 ) }组件输入网络结构输出特点Policy图像观测编码器MLP[256,256]动作分布(μ,σ)纯监督学习3.1.3 BCAgent 编码器架构small 编码器encoders { image_key: SmallEncoder( features(32, 64, 128, 256), kernel_sizes(3, 3, 3, 3), strides(2, 2, 2, 2), paddingVALID, pool_methodavg, bottleneck_dim256, spatial_block_size8, ) }resnet 编码器encoders { image_key: resnetv1_configs[resnetv1-10]( pooling_methodspatial_learned_embeddings, num_spatial_blocks8, bottleneck_dim256, ) }resnet-pretrained 编码器pretrained_encoder resnetv1_configs[resnetv1-10-frozen]( pre_poolingTrue, ) encoders { image_key: PreTrainedResNetEncoder( pooling_methodspatial_learned_embeddings, num_spatial_blocks8, bottleneck_dim256, pretrained_encoderpretrained_encoder, ) }3.1.4 BCAgent 损失函数def loss_fn(params, rng): # 前向传播 dist self.state.apply_fn( {params: params}, batch[observations], temperature1.0, trainTrue, rngs{dropout: key}, nameactor, ) pi_actions dist.mode() # 预测动作 log_probs dist.log_prob(batch[actions]) # 对数概率 # 多重损失 mse ((pi_actions - batch[actions]) ** 2).sum(-1) # MSE损失 actor_loss -(log_probs).mean() # 负对数似然 return actor_loss, { actor_loss: actor_loss, mse: mse.mean(), }3.1.5 BCAgent 动作采样def sample_actions(self, observations, seedNone, temperature1.0, argmaxFalse): dist self.state.apply_fn( {params: self.state.params}, observations, temperaturetemperature, nameactor, ) if argmax: actions dist.mode() # 确定性采样 else: actions dist.sample(seedseed) # 随机采样 return actions3.2 bc.py 流程图BC Agent (模仿学习) 核心流程图如下。BC 关键组件说明ResNet-10: SERL 的默认视觉骨架用于从原始像素中提取物理特征。Random Crop: 极其重要的 Trick通过对画面进行 ±4 像素的裁剪来模拟环境扰动。TanhNormal: 动作分布模型确保输出的动作符合机械臂的物理范围。Pre-training: BC 在 SERL 中扮演“冷启动”的角色将 RL 的搜索空间缩减到目标附近。3.3 bc.py 技术细节核心逻辑: update 函数不仅仅是回归: 虽然它计算了 mse, 但实际更新用的是 actor_loss -log_probs.mean()。这是一种概率视角下的模仿: 让策略在演示数据给出的状态下, 输出演示动作的概率尽可能大。没有 Critic这里完全没有 Q 网络。BC 纯粹是看着答案抄答案, 不需要奖励信号。视觉处理: 数据增强 (Data Augmentation) data_augmentation_fn这是 DrQ-v2 风格的随机裁剪。为什么重要?: 在机器人任务中, 摄像头画面可能会有轻微抖动。通过对演示图片进行随机裁剪, 可以让策略学会忽略这种位移, 提高鲁棒性。这也是 SERL 能在 20 分钟内学会任务的秘诀之一。网络架构: ResNet-10在 create 方法中, 它定义了三种编码器。SERL 默认推荐的是 resnet:elif encoder_type resnet: encoders { image_key: resnetv1_configs[resnetv1-10](...) }ResNet-10 比传统的 CNN 效果好得多, 因为它能提取更深层的特征, 同时又不像 ResNet-50 那样运算缓慢。为什么会有 mse 却不用它更新?在 update 函数的 loss_fn 中, 作者计算了 mse ((pi_actions - batch[actions]) ** 2).sum(-1), 但返回值的第一项 (真正的 Loss) 是 actor_loss -(log_probs).mean()。原因: MSE 只关心均值对不对, 而 log_prob 关心的是整个概率分布。如果人类演示同一个动作时有微小的偏差, log_prob 能更好地捕获这种容错性。EncodingWrapper 的作用这个包装器能把视觉图像和机械臂自身的状态 (关节角度、末端坐标) 揉在一起。这意味着机器人不仅知道自己看到了什么, 还知道自己现在手在哪。encoder_def EncodingWrapper(..., use_propriouse_proprio, enable_stackingTrue, ...)冷启动与热切换如果复现SERL一般会把 bc.py 练出来的模型会作为 RLPDAgent.create 时的初始参数 (或权重)。这相当于把原本需要几百万次尝试才能学会的动作, 压缩成了几千步的模仿。bc_loss 的数据生效为paddingVALID。这是因为不能对在线数据做 BC Loss在线数据 (智能体自己乱跑出来的) 非常乱。物理后果: 如果强迫智能体去模仿这些乱七八糟的动作, 就像是让一个正在学走路的孩子去模仿自己摔跤的动作。这会导致策略陷入低水平的循环。0x04 High UTD 稳定性机制High UTD 的意义它强迫神经网络在极短的时间内吃透每一张图片。4.1 High UTD把每条真机样本反复研读4.1.1 理解 UTDUpdate-to-Data RatioUTDUpdate-to-Data Ratio表示每采集一条环境数据算法进行多少次梯度更新。传统 RL 常用 UTD1采一步训一步。SERL / RLPD 使用更高 UTD通常为 20 甚至更高采集一条昂贵的真机数据后learner 会多次从 buffer 中采样并更新网络。4.1.2 为什么 SERL 需要高 UTD我们可以把 High UTD 理解成真机数据太贵所以每一帧都要反复研读不能看一遍就扔。没有 UTD 的后果普通的 SAC 每采样一个数据才更新一次。对于机器人这种高维度ResNet 图像且数据量极小只有 2.5 小时数据的任务收敛太慢你可能需要练上 10 天半个月。不稳定性由于视觉特征ResNet需要海量更新才能稳定如果更新频率太低视觉头会一直处于模糊状态无法提取有效的位姿信息。4.1.3 极致的采样效率High UTD 将数据的价值榨取到了极致20x 的复习强度 利用 High UTDUpdate-to-Data策略机器人每在现实中走一步后台 Learner 就会对现有数据进行 20 到 40 次的高频更新。REDQ 保驾护航 为了防止这种高强度学习产生幻觉系统利用 10–20 个 Critic 组成 陪审团Ensemble通过取最小值的方式压制过估计偏差。成果数据压榨逻辑真机采集的数据中隐藏着极其细微的物理交互特征如手爪与工件的摩擦。通过 High UTD模型被迫对有限的样本进行深度研读从而在极短的物理时间内捕捉到成功的信号。这让原本需要几周的训练过程被压缩到了喝杯咖啡的时间。工程代价High UTD 对 Learner 的算力提出了严苛要求。这要求 JAX 必须在毫秒级时间内完成多轮反向传播以确保学习速度始终领先于采样速度。4.1.4 实现partial(jax.jit, static_argnames(utd_ratio, pmap_axis)) def update_high_utd( self, batch: Batch, *, utd_ratio: int, pmap_axis: Optional[str] None, ) - Tuple[SACAgent, dict]: Fast JITted high-UTD version of .update. Splits the batch into minibatches, performs utd_ratio critic (and target) updates, and then one actor/temperature update. Batch dimension must be divisible by utd_ratio. batch_size batch[rewards].shape[0] assert ( batch_size % utd_ratio 0 ), fBatch size {batch_size} must be divisible by UTD ratio {utd_ratio} minibatch_size batch_size // utd_ratio chex.assert_tree_shape_prefix(batch, (batch_size,)) def scan_body(carry: Tuple[SACAgent], data: Tuple[Batch]): (agent,) carry (minibatch,) data agent, info agent.update( minibatch, pmap_axispmap_axis, networks_to_updatefrozenset({critic}) ) return (agent,), info def make_minibatch(data: jnp.ndarray): return jnp.reshape(data, (utd_ratio, minibatch_size) data.shape[1:]) minibatches jax.tree_map(make_minibatch, batch) (agent,), critic_infos jax.lax.scan(scan_body, (self,), (minibatches,)) critic_infos jax.tree_map(lambda x: jnp.mean(x, axis0), critic_infos) del critic_infos[actor] del critic_infos[temperature] # Take one gradient descent step on the actor and temperature agent, actor_temp_infos agent.update( batch, pmap_axispmap_axis, networks_to_updatefrozenset({actor, temperature}), ) del actor_temp_infos[critic] infos {**critic_infos, **actor_temp_infos} return agent, infos4.1.5 UTD 降为 1 的后果从特训班降级为自习室结论如果把 UTD 降为 1效果会大幅变差甚至完全学不会。把 cta_ratioUTD 比率从 20 降到 1导致效果变差的原理主要有三点导师还没谱学生瞎改Critic Lag在 SAC 中Actor学生的更新是基于 Critic导师给出的 Q 值梯度的。UTD 20每跑一步Critic 都要刷题 20 次。这让 Critic 能迅速消化新产生的数据把 Q 值身价算得非常准。当 Actor 来问我该怎么改时Critic 给出的方向是极其精准的。UTD 1Critic 只能练一次就给建议。对于精密插件任务由于视觉特征像素极其复杂Q 值需要海量的更新才能捕捉到插头对准插座那一瞬间的剧烈价值波动。此时 Critic 的评估可能还是模糊的。后果Actor 顺着错误的梯度方向去改只会越练越废。视觉特征提取的滞后Encoder TrainingSERL 使用的是从像素开始的端到端学习。原理ResNet 需要大量的反向传播才能从杂乱的背景中认出插孔。对比20 倍的更新频率意味着视觉编码器Encoder的学习速度提高了 20 倍。如果 UTD1你的机器人可能练了一整天ResNet 还没看清物体的轮廓。数据压榨率Sample Reuse机器人数据是昂贵的需要电机动需要时间。UTD 20每一条真实的物理轨迹都会被拿出来反复揉搓 20 遍。UTD 1这条数据用一次就扔了就像富二代在浪费极其稀有的成功样本。后果在数据极其稀疏的精密操作中UTD1 会导致机器人根本无法在有限的 1 小时训练内通过自发尝试撞到正确答案。4.1.6 High UTD 的副作用与应对但 High UTD 也有副作用。对同一批数据反复训练critic 容易过拟合和过估计最终让策略崩溃。因此SERL 还需要配套的稳定性机制。4.2 稳定性机制High UTD 是发动机但发动机太猛就需要刹车系统。SERL 通过多种机制的协同实现了在极高样本效率下的稳定训练。整体稳定性保障Critic Ensemble / REDQ多个 critic 像陪审团随机子采样取最小值压制 High UTD 带来的估值爆炸Critic LayerNorm让高频更新不至于数值失控支持更高 UTD ratiosSoft Update让目标网络缓慢跟随维持 Bellman 目标平稳保证策略更新平滑RLPD 50/50 采样demo 数据作为锚点防止策略偏离专家分布DrQ 数据增强random_crop 提供最重要的视觉正则化Actor encoder stop-gradient防止 actor loss 破坏视觉表征这些机制不是孤立工作的而是协同配合。SERL 的工程价值在于不是单独实现某个技巧而是把一整套相互配合的稳定性机制整合起来使得高 UTD 这种激进的训练策略能够在真机上稳定运行形成一套可工作的系统。SAC 的巧妙之处恰恰在于它如何利用不确定性来获得最终的稳定。我们接下来选择部分机制进行解读。0x05 Layer Normalization论文中提到regularizing the critic with layer normalization allows for higher UTD ratios and thus more efficient training。也就是说SERL 并不是单纯把 UTD 拉高而是通过 critic 正则化让高频更新不至于数值失控。即为了抗住 20 倍的更新强度而不崩盘SERL 在 Critic 网络中引入了LayerNorm。这在传统 SAC 中是不常见的但在高 UTD 的 RLPD 算法中至关重要。5.1 当前层归一化实现分析从 MLP可以看到已有的层归一化支持class MLP(nn.Module): use_layer_norm: bool False # 层归归一化开关 nn.compact def __call__(self, x: jnp.ndarray, train: bool False) - jnp.ndarray: for i, size in enumerate(self.hidden_dims): x nn.Dense(size, kernel_initdefault_init())(x) # 线性变换 if i 1 len(self.hidden_dims) or self.activate_final: # 正则化层可选 if self.dropout_rate is not None and self.dropout_rate 0: x nn.Dropout(rateself.dropout_rate)(x, deterministicnot train) if self.use_layer_norm: # 关键层归一化应用 x nn.LayerNorm()(x) # 标准化层输出 x activations(x) # 激活函数 return x5.2 层归一化的技术细节和优势Critic 网络的特点输入方差大观测编码和动作拼接导致输入分布不稳定梯度爆炸风险深度网络容易出现梯度问题Ensemble 训练多个 Critic 网络需要稳定的训练动态层归一化的具体好处稳定训练减少内部协变量偏移提高学习率可以使用更大的学习率加速收敛减少训练震荡改善泛化对输入扰动更鲁棒5.3 实现方案SACAgent 创建时启用层归一化.critic_network_kwargs{ activations: nn.tanh, use_layer_norm: True, hidden_dims: [256, 256], }, policy_network_kwargs{ activations: nn.tanh, use_layer_norm: True, hidden_dims: [256, 256], },针对 DrQAgent 的实现critic_network_kwargs{ activations: nn.tanh, use_layer_norm: True, hidden_dims: [256, 256], }, policy_network_kwargs{ activations: nn.tanh, use_layer_norm: True, hidden_dims: [256, 256], },VICE 中的层归一化已实现critic_network_kwargs{ activations: nn.tanh, use_layer_norm: True, hidden_dims: [256, 256], }, vice_network_kwargs{ activations: nn.leaky_relu, use_layer_norm: True, hidden_dims: [ 256, ], dropout_rate: 0.1, }, policy_network_kwargs{ activations: nn.tanh, use_layer_norm: True, hidden_dims: [256, 256], },5.4 总结对 Critic 进行层归一化正则化的关键是配置启用在critic_network_kwargs中设置use_layer_norm: True正则化组合配合 Dropout 和权重衰减获得最佳效果超参数调整层归一化后可以使用更大的学习率针对性优化根据任务类型视觉/状态调整正则化强度性能监控添加统计信息验证层归一化的实际效果在 SERL 框架中这种实现方式既保持了代码的简洁性又充分利用了 Flax/JAX 的模块化优势是提高 Critic 网络训练稳定性和性能的有效手段。0x06 Soft Update REDQSoft Update 让目标网络始终缓慢追踪当前 Q 值保持贝尔曼目标的平稳性。在 REDQ 的高 UTD 场景下尤为重要。6.1 Soft Update 的力量在机器人控制中动作的连续性决定了硬件的寿命。SERL 坚持使用 Soft Update软更新维护目标网络平滑公式θ(target) τ θ_online (1−τ) θ{target}。其中 τ 通常设为极其微小的 0.005。硬件意义与直接拷贝权重的 Hard Update 不同Soft Update 让目标值Target以一种近乎流体的方式缓慢漂移。这反映到机器人身上就是动作的进化是渐进的不会因为模型权重的突跳导致机械臂产生瞬时的冲击电流或抖动。Soft update的核心实现如下def target_update(self, tau: float) - JaxRLTrainState: Performs an update of the target params via polyak averaging. The new target params are given by: new_target_params tau * params (1 - tau) * target_params new_target_params jax.tree_map( lambda p, tp: p * tau tp * (1 - tau), self.params, self.target_params ) return self.replace(target_paramsnew_target_params)这个方法在 SACAgent.update 中被调用# Update target network (if requested) if critic in networks_to_update: new_state new_state.target_update(self.config[soft_target_update_rate])原理分析Soft Update采用Polyak averaging的方式缓慢更新目标网络。这种方法的核心思想是让目标网络以平滑的方式跟踪主网络而不是周期性地完全复制。这种平滑跟踪有助于减少训练过程中的方差提高算法的稳定性防止因目标网络剧烈变化导致的训练震荡6.2 REDQCritic Ensemble 抑制过估计REDQ 模式支持把 Q 网络增加到 10 个以上并从中随机抽 2 个来计算 Target。这是另一种对抗高估问题的强力方法。原生2 个 CriticSERL10–20 个 Critic作用支持 High UTDUTD20。如果没有这么多 Critic 压阵SAC 会在疯狂更新中产生严重的数值爆炸。6.2.1 为什么需要 Critic EnsembleHigh UTD 虽然能加速学习但会带来致命副作用Q 值过估计Overestimation Bias。模型会因为反复研读少量样本而变得极端自信单Q网络容易高估未见过的状态一动作对的价值最终导致策略崩溃。SERL 引入了 REDQRandomized Ensembled Double Q-learning风格的机制来解决这个问题。我们可以将其理解为一种陪审团机制Critic Ensemble陪审团同时训练 10 到 20 个独立运行的 Critic 网络随机子集采样Randomized Subset在计算目标 Q 值时并不看所有人的意见而是随机抽取 2 个 Critic取最小值In-sample Min在抽出的子集中取分数的最小值min操作天然抑制0OD区域的过高估计通俗地说如果十个裁判里随机抽出的几个裁判中有一个觉得这个动作危险那我们就保守一点。这种悲观主义巧妙地抵消了 High UTD 带来的狂热乐观使训练在极高强度下依然稳如磐石。6.2.2 REDQ论文算法算法如下训练时的行为随机采样从 N 个默认 N10Critic 网络中随机无放回地选取 M 个默认 M2子集索引。前向传播仅将这 M 个网络的参数用于计算目标状态动作值 Q(s′,a′)Q(s′,a′)。取最小值对这 M 个输出值取最小值作为 Target Q 值的一部分即 mini∈subsetQi(s′,a′)。损失计算虽然 Target 只用了 M 个网络但在计算 Critic 损失时所有 N 个网络都会根据同一个 Target 进行梯度更新以维持集成多样性 。这种设计既保持了ensemble的容量优势又通过子采样降低了计算成本和过拟合风险。计算效率只计算 K 个网络的前向传播而非全部 N 个正则化效果随机子采样引入额外噪声提高泛化能力过拟合缓解避免始终使用相同的最好网络REDQ 论文证明min(2 from 10)的效果接近min(10)但计算量减少 5 倍。6.2.3 SERL 的实现架构细节# drq.py:124-125 和 launcher.py:165-166 critic_ensemble_size10, # 10 个独立 Q 网络 critic_subsample_size2, # 计算 target 时只随机选 2 个具体实现def critic_loss_fn(self, batch, params: Params, rng: PRNGKey): # ...前期准备代码... # 1. 计算所有ensemble成员的Q值 target_next_qs self.forward_target_critic( batch[next_observations], next_actions, rngrng, ) # shape: (critic_ensemble_size, batch_size) # 2. 如果配置了子采样则随机选择指定数量的网络 if self.config[critic_subsample_size] is not None: rng, subsample_key jax.random.split(rng) subsample_idcs jax.random.randint( subsample_key, (self.config[critic_subsample_size],), # 通常是2 0, self.config[critic_ensemble_size], # 通常是10 ) target_next_qs target_next_qs[subsample_idcs] # 只保留选中的网络 # 3. 在(子采样后的)ensemble成员中取最小值 target_next_min_q target_next_qs.min(axis0) # shape: (batch_size,) # ...后续使用target_next_min_q计算TD目标...与论文的区别SERL 先计算所有10个网络的Q值然后随机选择2个网络进行子采样最后在这2个中选择最小值。详细执行流程前向传播所有ensemble首先计算所有10个critic网络对(next_state, next_action)的Q值随机子采样从10个网络中随机选择2个网络的索引subsample_idcs提取子样本target_next_qs[subsample_idcs]只保留这2个网络的Q值取最小值在这2个网络的Q值中取最小值作为bootstrapping目标REDQ为什么这样设计前向传播成本低注释中明确提到Evaluate next Qs for all ensemble members (cheap because were only doing the forward pass)因为没有梯度计算计算所有10个网络的代价相对较小子采样减少复杂度后续操作梯度计算等只在2个网络上进行大大减少了计算复杂度随机性的好处每次更新随机选择不同的网络组合相当于隐式的bagging增加了训练的随机性和泛化能力过拟合缓解通过随机子采样和取最小值的方式能够有效缓解critic过拟合问题这种设计在保持计算效率的同时充分利用了ensemble的多样性是REDQ算法的核心创新点之一。6.3 时间戳对齐代码中完全没有观测 - 动作的时间戳对齐机制。但这是刻意的设计取舍实际的时序设计step(action) 被调用 ├─1. 计算目标位姿 安全裁剪 ~0.1ms ├─2. _send_gripper_command() ~600ms (夹爪动作时) ├─3. _send_pos_command() ~5ms (HTTP POST) ├─4. time.sleep(1/hz - elapsed) 补齐到 100ms 控制周期 ├─5. _update_curipos() ~5ms (HTTP POST /getstate) └─6. _get_obs() ├ get_im() ~10~30ms (从 VideoCapture 队列取最新) └ 组装 state (用步骤5的数据)为什么不做精确对齐相对偏差可接受10Hz 控制周期 100ms 间隔相机 30fps 33ms 间隔最坏情况图像延迟 33ms相对控制周期只有 1/3 偏差硬件限制精确对齐需要硬件触发同步当前架构不支持算法容错性SAC/DrQ 学习的是随机策略天然对观测噪声有一定容忍度SERL 在系统设计上做出了取舍专注于整体架构的简洁性和可维护性而非追求极致的硬件同步精度。0x07 SAC vs RLPD7.1 相同之处RLPD 的 Critic 更新和 SAC 没有区别。在 SAC 中Target r γ min(Q₁, Q₂)(s, π(s))在 RLPD 中Target r γ min(Q₁, Q₂)(s, π(s))对于 Batch 里的每一个样本无论是来自 Demo 还是 OnlineCritic 都在做同一件事Loss (Q(s, a) - Target)²其中 Target r_{vice} γ Q(s, a)。注意这里的 a 就是当时实际发生的那个动作。如果是 Online 数据a 是机器人自己做的。如果是 Demo 数据a 是人做的。Critic 并不在乎这个动作是谁做的它只负责给这个动作-状态对打分。7.2 不同之处这里的特殊处理不在公式里而是在数据的质量上。RLPD 并没有给 Critic 写一个针对 Demo 的特殊公式它的特殊在于强行喂给 Critic 50% 的高质量样本。秘密不在怎么算Loss Function而在喂什么Batch Composition。Online 样本机器人的动作可能是乱挥a_{random}它对应的 r_{vice} 大概率是 0。Demo 样本人类的动作是精准的a_{demo}它对应的 r_{vice} 大概率是 1或者接近 1 的高分。所有数据无论谁做的都必须经过 VICE 的安检。VICE 说它是成功它才是成功。当 VICE裁判坏了始终给 0 分时会发生什么不公平的待遇原本该拿 1 分的 a_{demo} 现在只能拿 0 分。在线试错 被 VICE 判为 0。人类演示 也被 VICE 判为 0。后果Critic 学到的是不管是人类做的那个精妙动作还是我刚才那个乱挥的动作身价全都是 0。Critic 只能接受着这个事实 —— 这个世界没有奖赏做什么都是徒劳。Q.网络随之萎缩到 0。7.3 通俗解释想象你和机器人都在学数学普通 SAC像自学机器人每天自己做 256 道题。因为是新手错题率 99%。机器人看着这些错题r0吃力地总结经验。弱点因为从来没见过正确解法r1它可能要练 1 万年才能偶尔撞对一次。RLPD像有标准答案的刷题机器人每天做 128 道自己的题Online另外读 128 道满分卷子Demo。它计算 Critic Loss 时虽然公式一样但那 128 道满分卷子带来的 Target 里的 r 是 1。强项它时刻被提醒什么是正确的。Critic 会迅速学到哦人类做的那些动作才是值钱的我自己瞎搞的那些不值钱。