egpo进行train_egpo训练时,KeyError: "replay_sequence_length"

def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Build the EGPO training execution plan.

    Concurrently (1) collects rollouts into a local replay buffer and
    (2) replays batches from it to take SGD steps, periodically updating
    the target network. Only the output of step (2) is returned, wrapped
    in the standard metrics reporting iterator.

    Args:
        workers: Rollout worker set used for sampling and training.
        config: Trainer config dict (replay, batch and update settings).

    Returns:
        A LocalIterator over training result dicts.
    """
    # Priority-replay kwargs are only forwarded when the feature is enabled;
    # otherwise LocalReplayBuffer gets its defaults.
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    else:
        prio_args = {}

    # NOTE(review): ``replay_sequence_length`` is deliberately NOT passed.
    # On ray 1.4.1 the LocalReplayBuffer constructor does not accept that
    # keyword, so forwarding config["replay_sequence_length"] makes the
    # trainer fail with KeyError/TypeError ("replay_sequence_length").
    # If you upgrade ray to a version whose LocalReplayBuffer supports it,
    # re-add ``replay_sequence_length=config["replay_sequence_length"]``.
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        **prio_args)

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # EGPO-specific: refresh the saver-policy penalty on each rollout batch.
    rollouts = rollouts.for_each(UpdateSaverPenalty(workers))

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer.
    # Calling next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    def update_prio(item):
        """Push the learner's per-policy TD errors back into the buffer."""
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get("td_error",
                                    info[LEARNER_STATS_KEY].get("td_error"))
                prio_dict[policy_id] = (samples.policy_batches[policy_id]
                                        .data.get("batch_indexes"), td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take a SGD step, and then we decide whether to update the target network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers)) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2). Only return the output
    # of (2) since training metrics are not available until (2) runs.
    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config))

    return StandardMetricsReporting(train_op, workers, config)
相关推荐
测试199841 分钟前
Selenium自动化测试+OCR-获取图片页面小说详解
自动化测试·软件测试·python·selenium·测试工具·ocr·测试用例
闲人编程43 分钟前
使用MLflow跟踪和管理你的机器学习实验
开发语言·人工智能·python·机器学习·ml·codecapsule
panplan.top1 小时前
Tornado + Motor 微服务架构(Docker + 测试 + Kubernetes)
linux·python·docker·微服务·k8s·tornado
看兵马俑的程序员1 小时前
RAG实现-本地PDF内容加载和切片
开发语言·python·pdf
是梦终空2 小时前
计算机毕业设计240—基于python+爬虫+html的微博舆情数据可视化系统(源代码+数据库)
爬虫·python·pandas·课程设计·毕业论文·计算机毕业设计·微博舆情可视化
CodeJourney.2 小时前
Python开发可视化音乐播放器教程(附代码)
数据库·人工智能·python
爱学习的小鱼gogo3 小时前
pyhton 螺旋矩阵(指针-矩阵-中等)含源码(二十六)
python·算法·矩阵·指针·经验·二维数组·逆序
言之。3 小时前
Andrej Karpathy 演讲【PyTorch at Tesla】
人工智能·pytorch·python
赵谨言3 小时前
基于Python楼王争霸劳动竞赛数据处理分析
大数据·开发语言·经验分享·python
智启七月4 小时前
谷歌 Gemini 3.0 正式发布:一键生成 Web OS,编程能力碾压竞品
人工智能·python