用深度强化学习来玩Chrome小恐龙快跑

目录

实机演示

代码实现


实机演示

用深度强化学习来玩Chrome小恐龙快跑

代码实现

python 复制代码
import os
import cv2
from pygame import RLEACCEL
from pygame.image import load
from pygame.sprite import Sprite, Group, collide_mask
from pygame import Rect, init, time, display, mixer, transform, Surface
from pygame.surfarray import array3d
import torch
from random import randrange, choice
import numpy as np

# Configure the mixer before pygame.init() so the audio settings take effect.
mixer.pre_init(44100, -16, 2, 2048)
init()

# Window size in pixels; `width` and `height` are reused throughout the file.
scr_size = (width, height) = (600, 150)
# Frame-rate cap handed to clock.tick() at the end of every environment step.
FPS = 60
# Added to the dinosaur's vertical speed on each update while it is jumping.
gravity = 0.6

# RGB colours; `background_col` is the fill colour of the window and score panels.
black = (0, 0, 0)
white = (255, 255, 255)
background_col = (235, 235, 235)

# NOTE(review): not read anywhere in the visible code.
high_score = 0

# The display must exist before any Surface.convert() call made by the loaders.
screen = display.set_mode(scr_size)
clock = time.Clock()
display.set_caption("T-Rex Rush")


def load_image(
        name,
        sizex=-1,
        sizey=-1,
        colorkey=None,
):
    """Load one image from ``assets/sprites`` and return it with its rect.

    Args:
        name: File name of the image inside ``assets/sprites``.
        sizex: Target width in pixels; -1 means "not specified".
        sizey: Target height in pixels; -1 means "not specified".
        colorkey: Colour to treat as transparent. ``None`` sets no colour
            key; -1 uses the colour of the pixel at (0, 0).

    Returns:
        A ``(surface, rect)`` tuple where ``rect`` is ``surface.get_rect()``.
    """
    fullname = os.path.join("assets/sprites", name)
    image = load(fullname).convert()

    if colorkey is not None:
        # Compare by value: ``is -1`` depends on CPython's small-int cache and
        # triggers a SyntaxWarning on Python 3.8+.
        if colorkey == -1:
            colorkey = image.get_at((0, 0))
        image.set_colorkey(colorkey, RLEACCEL)

    # Scaling is applied as soon as either dimension is given; both values
    # are handed to transform.scale unchanged.
    if sizex != -1 or sizey != -1:
        image = transform.scale(image, (sizex, sizey))

    return (image, image.get_rect())


def load_sprite_sheet(
        sheetname,
        nx,
        ny,
        scalex=-1,
        scaley=-1,
        colorkey=None,
):
    """Cut a sprite sheet from ``assets/sprites`` into individual frames.

    Args:
        sheetname: File name of the sheet inside ``assets/sprites``.
        nx: Number of columns. Either an int, or a list of ints; with a list
            the sheet is sliced once per entry (each time into that many
            equal-width frames) and all resulting frames are appended.
        ny: Number of rows.
        scalex: Target frame width (-1 means "not specified"). Must be a
            sequence parallel to ``nx`` when ``nx`` is a list.
        scaley: Target frame height; -1 means "not specified".
        colorkey: Colour to treat as transparent. ``None`` sets no colour
            key; -1 uses the pixel at (0, 0) of the first frame that is cut,
            and that colour is then reused for every following frame.

    Returns:
        ``(sprites, sprite_rect)`` where ``sprites`` is the list of frames
        and ``sprite_rect`` is the rect of the first frame.
    """
    fullname = os.path.join("assets/sprites", sheetname)
    sheet = load(fullname).convert()
    sheet_rect = sheet.get_rect()

    # Treat the single-int form as a one-entry list so both call styles share
    # the same slicing loop.
    if isinstance(nx, int):
        column_counts, column_scales = [nx], [scalex]
    else:
        column_counts, column_scales = nx, scalex

    sizey = sheet_rect.height / ny
    frame_widths = [sheet_rect.width / count for count in column_counts]

    sprites = []
    for row in range(ny):
        for count, sizex, i_scalex in zip(column_counts, frame_widths, column_scales):
            for col in range(count):
                rect = Rect((col * sizex, row * sizey, sizex, sizey))
                image = Surface(rect.size).convert()
                image.blit(sheet, (0, 0), rect)

                if colorkey is not None:
                    # Compare by value: ``is -1`` depends on CPython's
                    # small-int cache and warns on Python 3.8+.
                    if colorkey == -1:
                        colorkey = image.get_at((0, 0))
                    image.set_colorkey(colorkey, RLEACCEL)

                if i_scalex != -1 or scaley != -1:
                    image = transform.scale(image, (i_scalex, scaley))

                sprites.append(image)

    sprite_rect = sprites[0].get_rect()

    return sprites, sprite_rect


def extractDigits(number):
    """Split a non-negative integer into decimal digits, most significant first.

    The result is left-padded with zeros to at least five digits; numbers
    with more than five digits are returned in full.

    Args:
        number: The value to split (expected to be an int).

    Returns:
        A list of digits, or ``None`` when ``number`` is not greater than -1.
    """
    if number <= -1:
        return None

    digits = []
    # Integer division is required here: with true division the loop only
    # stopped once the value had already reached 0, which appended a spurious
    # leading zero to every number.
    while True:
        number, digit = divmod(number, 10)
        digits.append(digit)
        if number == 0:
            break

    digits.extend([0] * (5 - len(digits)))
    digits.reverse()
    return digits


def pre_processing(image, w=84, h=84):
    """Turn a raw frame into a binarised single-channel float32 array.

    The frame is cut to its first 300 entries along axis 0, resized to
    ``(w, h)``, converted to grey scale and thresholded at 127 so every
    pixel ends up as 0 or 255.

    Args:
        image: Three-channel frame array (the caller passes the output of
            ``pygame.surfarray.array3d``).
        w: Width passed to ``cv2.resize``.
        h: Height passed to ``cv2.resize``.

    Returns:
        A ``float32`` array with a leading axis of length 1 in front of the
        thresholded image.
    """
    cropped = image[:300, :, :]
    resized = cv2.resize(cropped, (w, h))
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    return binary[None, :, :].astype(np.float32)


class Dino():
    """Player dinosaur: animation frames, jump physics and score counter.

    ``images`` holds the five frames cut from ``dino.png`` (0-1 are used
    while blinking, 2-3 while running, 4 once dead); ``images1`` holds the
    two ducking frames cut from ``dino_ducking.png``.
    """

    def __init__(self, sizex=-1, sizey=-1):
        """Load both sprite sheets and place the dinosaur on the ground line."""
        self.images, self.rect = load_sprite_sheet("dino.png", 5, 1, sizex, sizey, -1)
        self.images1, self.rect1 = load_sprite_sheet("dino_ducking.png", 2, 1, 59, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width / 15

        # Animation state.
        self.image = self.images[0]
        self.index = 0
        self.counter = 0

        # Gameplay state.
        self.score = 0
        self.isJumping = False
        self.isDead = False
        self.isDucking = False
        self.isBlinking = False
        self.movement = [0, 0]
        self.jumpSpeed = 11.5

        # Rect widths swapped in by update() when the pose changes.
        self.stand_pos_width = self.rect.width
        self.duck_pos_width = self.rect1.width

    def draw(self):
        """Blit the current frame onto the global screen."""
        screen.blit(self.image, self.rect)

    def checkbounds(self):
        """Clamp the dinosaur to the ground line and end the jump there."""
        ground_line = int(0.98 * height)
        if self.rect.bottom > ground_line:
            self.rect.bottom = ground_line
            self.isJumping = False

    def update(self):
        """Advance physics, animation frame, position and score by one tick."""
        if self.isJumping:
            # Gravity only accumulates while airborne; the jump uses frame 0.
            self.movement[1] = self.movement[1] + gravity
            self.index = 0
        elif self.isBlinking:
            # Frame 0 is held for 400 ticks, frame 1 for 20 ticks.
            period = 400 if self.index == 0 else 20
            if self.counter % period == period - 1:
                self.index = (self.index + 1) % 2
        elif self.counter % 5 == 0:
            toggled = (self.index + 1) % 2
            # Ducking alternates 0/1, running alternates 2/3.
            self.index = toggled if self.isDucking else toggled + 2

        if self.isDead:
            self.index = 4

        if self.isDucking:
            self.image = self.images1[(self.index) % 2]
            self.rect.width = self.duck_pos_width
        else:
            self.image = self.images[self.index]
            self.rect.width = self.stand_pos_width

        self.rect = self.rect.move(self.movement)
        self.checkbounds()

        # One point on every seventh tick while alive and not blinking.
        scoring_tick = self.counter % 7 == 6
        if scoring_tick and not self.isDead and not self.isBlinking:
            self.score += 1

        self.counter += 1


class Cactus(Sprite):
    """Cactus obstacle that starts beyond the right edge and scrolls left.

    ``containers`` is a class attribute that must be assigned before the
    first instance is created; new sprites join that group automatically.
    """

    def __init__(self, speed=5, sizex=-1, sizey=-1):
        """Pick one of the eleven cactus frames and place it off-screen right."""
        super().__init__(self.containers)
        frames, frame_rect = load_sprite_sheet("cacti-small.png", [2, 3, 6], 1, sizex, sizey, -1)
        frame_rect.bottom = int(0.98 * height)
        frame_rect.left = width + frame_rect.width
        self.images = frames
        self.rect = frame_rect
        self.image = frames[randrange(0, 11)]
        self.movement = [-speed, 0]

    def draw(self):
        """Blit the cactus onto the global screen."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Move by ``movement`` and remove the sprite once it left the screen."""
        moved = self.rect.move(self.movement)
        self.rect = moved
        if moved.right < 0:
            self.kill()


class Ptera(Sprite):
    """Flying obstacle with a two-frame wing animation.

    ``containers`` is a class attribute that must be assigned before the
    first instance is created; new sprites join that group automatically.
    """

    def __init__(self, speed=5, sizex=-1, sizey=-1):
        """Spawn off-screen right at one of four randomly chosen heights."""
        super().__init__(self.containers)
        frames, frame_rect = load_sprite_sheet("ptera.png", 2, 1, sizex, sizey, -1)
        self.images = frames
        self.rect = frame_rect
        self.ptera_height = [height * factor for factor in (0.82, 0.75, 0.60, 0.48)]
        self.rect.centery = self.ptera_height[randrange(0, 4)]
        self.rect.left = width + self.rect.width
        self.image = frames[0]
        self.movement = [-speed, 0]
        self.index = 0
        self.counter = 0

    def draw(self):
        """Blit the current frame onto the global screen."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Animate, move by ``movement`` and remove the sprite once off-screen."""
        # Switch wing frame every tenth tick.
        if self.counter % 10 == 0:
            self.index = (self.index + 1) % 2
        self.image = self.images[self.index]
        self.rect = self.rect.move(self.movement)
        self.counter += 1
        if self.rect.right < 0:
            self.kill()


class Ground():
    """Endlessly scrolling ground made of two copies of the same strip."""

    def __init__(self, speed=-5):
        """Load the strip twice and lay the second copy right after the first."""
        self.image, self.rect = load_image("ground.png", -1, -1, -1)
        self.image1, self.rect1 = load_image("ground.png", -1, -1, -1)
        self.rect.bottom = self.rect1.bottom = height
        self.rect1.left = self.rect.right
        self.speed = speed

    def draw(self):
        """Blit both strips onto the global screen."""
        for strip_image, strip_rect in ((self.image, self.rect), (self.image1, self.rect1)):
            screen.blit(strip_image, strip_rect)

    def update(self):
        """Shift both strips by ``speed`` and recycle one that left the screen."""
        for strip_rect in (self.rect, self.rect1):
            strip_rect.left += self.speed

        # A strip that scrolled out on the left is re-attached behind the other.
        if self.rect.right < 0:
            self.rect.left = self.rect1.right

        if self.rect1.right < 0:
            self.rect1.left = self.rect.right


class Cloud(Sprite):
    """Decorative cloud drifting left at one pixel per update.

    ``containers`` is a class attribute that must be assigned before the
    first instance is created; new sprites join that group automatically.
    """

    def __init__(self, x, y):
        """Create a cloud whose top-left corner is at ``(x, y)``."""
        super().__init__(self.containers)
        cloud_image, cloud_rect = load_image("cloud.png", int(90 * 30 / 42), 30, -1)
        cloud_rect.left = x
        cloud_rect.top = y
        self.image = cloud_image
        self.rect = cloud_rect
        self.speed = 1
        self.movement = [-self.speed, 0]

    def draw(self):
        """Blit the cloud onto the global screen."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Move by ``movement`` and remove the sprite once it left the screen."""
        moved = self.rect.move(self.movement)
        self.rect = moved
        if moved.right < 0:
            self.kill()


class Scoreboard():
    """Five-digit score panel rendered from the ``numbers.png`` sprite sheet."""

    def __init__(self, x=-1, y=-1):
        """Create the panel; -1 for ``x`` or ``y`` selects the default position."""
        self.score = 0
        self.tempimages, self.temprect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.image = Surface((55, int(11 * 6 / 5)))
        self.rect = self.image.get_rect()
        self.rect.left = width * 0.89 if x == -1 else x
        self.rect.top = height * 0.1 if y == -1 else y

    def draw(self):
        """Blit the panel onto the global screen."""
        screen.blit(self.image, self.rect)

    def update(self, score):
        """Redraw the panel so that it shows ``score``."""
        digits = extractDigits(score)
        self.image.fill(background_col)
        # A six-digit result loses its first digit so that five are drawn.
        if len(digits) == 6:
            digits = digits[1:]
        for digit in digits:
            self.image.blit(self.tempimages[digit], self.temprect)
            # Slide the shared digit rect one glyph to the right.
            self.temprect.move_ip(self.temprect.width, 0)
        # Rewind the shared digit rect for the next call.
        self.temprect.left = 0


class ChromeDino(object):
    """The dinosaur game exposed as a frame-by-frame environment.

    Every call to :meth:`step` applies one action, advances and renders one
    frame, and returns the pre-processed frame together with a reward and a
    terminal flag. The game is reset automatically after a collision.
    """

    def __init__(self):
        """Create the player, ground, score panels and empty obstacle groups."""
        self.gamespeed = 5
        self.gameOver = False
        self.gameQuit = False
        self.playerDino = Dino(44, 47)
        self.new_ground = Ground(-1 * self.gamespeed)
        self.scb = Scoreboard()
        self.highsc = Scoreboard(width * 0.78)
        self.counter = 0

        self.cacti = Group()
        self.pteras = Group()
        self.clouds = Group()
        # Holds only the most recently spawned obstacle; used to space spawns.
        self.last_obstacle = Group()

        # Sprites add themselves to these groups on construction.
        Cactus.containers = self.cacti
        Ptera.containers = self.pteras
        Cloud.containers = self.clouds

        self.retbutton_image, self.retbutton_rect = load_image("replay_button.png", 35, 31, -1)
        self.gameover_image, self.gameover_rect = load_image("game_over.png", 190, 11, -1)

        # Build the two-glyph label from frames 10 and 11 of the digit sheet.
        self.temp_images, self.temp_rect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.HI_image = Surface((22, int(11 * 6 / 5)))
        self.HI_rect = self.HI_image.get_rect()
        self.HI_image.fill(background_col)
        self.HI_image.blit(self.temp_images[10], self.temp_rect)
        self.temp_rect.left += self.temp_rect.width
        self.HI_image.blit(self.temp_images[11], self.temp_rect)
        self.HI_rect.top = height * 0.1
        self.HI_rect.left = width * 0.73

    def step(self, action, record=False):  # 0: Do nothing. 1: Jump. 2: Duck
        """Apply ``action``, advance one frame and return the observation.

        Rewards: 0.1 per frame (0.11 when doing nothing), 1 on the frame an
        obstacle's right edge has just passed the dinosaur's left edge, and
        -1 on a collision.

        Args:
            action: 0 does nothing, 1 jumps, 2 ducks; other values leave the
                dinosaur's state untouched.
            record: When True the raw frame is returned as well.

        Returns:
            ``(state, reward, done)``, or ``(state, raw_frame, reward, done)``
            when ``record`` is True. ``state`` is the tensor built from
            ``pre_processing``; ``raw_frame`` is the frame converted from RGB
            to BGR with its first two axes swapped; ``done`` is True whenever
            the reward is not positive.
        """
        dino = self.playerDino
        ground_y = int(0.98 * height)

        reward = 0.1
        if action == 0:
            reward += 0.01
            dino.isDucking = False
        elif action == 1:
            dino.isDucking = False
            # A jump can only start from the ground line.
            if dino.rect.bottom == ground_y:
                dino.isJumping = True
                dino.movement[1] = -1 * dino.jumpSpeed
        elif action == 2:
            if not (dino.isJumping and dino.isDead) and dino.rect.bottom == ground_y:
                dino.isDucking = True

        # Cacti are checked before pteras, so a ptera result overrides a
        # cactus result; each group stops at its first hit or pass.
        for obstacles in (self.cacti, self.pteras):
            for obstacle in obstacles:
                obstacle.movement[0] = -1 * self.gamespeed
                if collide_mask(dino, obstacle):
                    dino.isDead = True
                    reward = -1
                    break
                if obstacle.rect.right < dino.rect.left < obstacle.rect.right + self.gamespeed + 1:
                    reward = 1
                    break

        if len(self.cacti) < 2:
            if len(self.cacti) == 0 and len(self.pteras) == 0:
                self.last_obstacle.empty()
                self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))
            else:
                for l in self.last_obstacle:
                    if l.rect.right < width * 0.7 and randrange(0, 50) == 10:
                        self.last_obstacle.empty()
                        self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))

        # Pteras only start to appear after the first 500 frames.
        if len(self.pteras) == 0 and len(self.cacti) < 2 and randrange(0, 50) == 10 and self.counter > 500:
            for l in self.last_obstacle:
                if l.rect.right < width * 0.8:
                    self.last_obstacle.empty()
                    self.last_obstacle.add(Ptera(self.gamespeed, 46, 40))

        if len(self.clouds) < 5 and randrange(0, 300) == 10:
            # randrange needs integers: float bounds raise TypeError on
            # Python 3.12+, so use floor division.
            Cloud(width, randrange(height // 5, height // 2))

        dino.update()
        self.cacti.update()
        self.pteras.update()
        self.clouds.update()
        self.new_ground.update()
        self.scb.update(dino.score)

        state = display.get_surface()
        screen.fill(background_col)
        self.new_ground.draw()
        self.clouds.draw(screen)
        self.scb.draw()
        self.cacti.draw(screen)
        self.pteras.draw(screen)
        dino.draw()

        display.update()
        clock.tick(FPS)

        if dino.isDead:
            self.gameOver = True

        self.counter = (self.counter + 1)

        # Reset the whole game in place; the frame drawn above is still what
        # gets returned for this step.
        if self.gameOver:
            self.__init__()

        state = array3d(state)
        done = not (reward > 0)
        if record:
            return torch.from_numpy(pre_processing(state)), np.transpose(
                cv2.cvtColor(state, cv2.COLOR_RGB2BGR), (1, 0, 2)), reward, done
        else:
            return torch.from_numpy(pre_processing(state)), reward, done
python 复制代码
import torch.nn as nn

class DeepQNetwork(nn.Module):
    """Convolutional Q-network mapping four stacked frames to three Q-values.

    The first linear layer expects ``7 * 7 * 64`` features, which is what
    the three convolutions produce for 84x84 inputs.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))

        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, 3)
        self._initialize_weights()

    def _initialize_weights(self):
        """Draw all conv/linear weights from U(-0.01, 0.01) and zero the biases."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.uniform_(module.weight, -0.01, 0.01)
            nn.init.constant_(module.bias, 0)

    def forward(self, input):
        """Return one row of three Q-values per sample in ``input``."""
        features = input
        for conv_block in (self.conv1, self.conv2, self.conv3):
            features = conv_block(features)
        # Flatten everything except the batch axis for the linear layers.
        flat = features.view(features.size(0), -1)
        return self.fc2(self.fc1(flat))
python 复制代码
import argparse
import torch

from src.model import DeepQNetwork
from src.env import ChromeDino
import cv2


def get_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        The ``argparse.Namespace`` with ``saved_path``, ``fps`` and ``output``.
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, so passing it positionally
    # replaced the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--fps", type=int, default=60, help="frames per second")
    parser.add_argument("--output", type=str, default="output/chrome_dino.mp4", help="the path to output video")

    return parser.parse_args()


def q_test(opt):
    """Play one game greedily with a trained network and record it to video.

    Args:
        opt: Parsed options providing ``saved_path`` (folder holding
            ``chrome_dino.pth``), ``output`` (video path) and ``fps``.
    """
    import os  # local import: this script's import block does not include os

    # torch.manual_seed seeds the CPU generator and all CUDA devices.
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DeepQNetwork()
    checkpoint_path = "{}/chrome_dino.pth".format(opt.saved_path)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    env = ChromeDino()
    state, _, _, _ = env.step(0, True)
    # The network takes four stacked frames; start with four copies of the first.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :].to(device)

    # VideoWriter does not create missing folders and fails silently without them.
    output_dir = os.path.dirname(opt.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps, (600, 150))
    try:
        done = False
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            while not done:
                prediction = model(state)[0]
                action = torch.argmax(prediction).item()
                next_state, raw_next_state, reward, done = env.step(action, True)
                out.write(raw_next_state)
                next_state = next_state.to(device)
                # Drop the oldest frame and append the newest one.
                state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
    finally:
        # Release the writer so the video file is finalised even on errors.
        out.release()



# Script entry point: parse the command line, then run the recorded evaluation.
if __name__ == "__main__":
    opt = get_args()
    q_test(opt)
python 复制代码
import argparse
import os
from random import random, randint, sample
import pickle
import numpy as np
import torch
import torch.nn as nn

from src.model import DeepQNetwork
from src.env import ChromeDino


def get_args():
    """Parse the command-line options of the training script.

    Returns:
        The ``argparse.Namespace`` holding the training hyper-parameters.
    """
    # The text must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, so passing it positionally
    # replaced the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--batch_size", type=int, default=64, help="The number of images per batch")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=0.1)
    parser.add_argument("--final_epsilon", type=float, default=1e-4)
    parser.add_argument("--num_decay_iters", type=float, default=2000000)
    parser.add_argument("--num_iters", type=int, default=2000000)
    # The previous help text described an unrelated option.
    parser.add_argument("--replay_memory_size", type=int, default=50000,
                        help="Maximum number of transitions kept in the replay memory")
    parser.add_argument("--saved_folder", type=str, default="trained_models")

    return parser.parse_args()


def _save_training_state(checkpoint_path, memory_path, iteration, model, optimizer, replay_memory):
    """Write the model/optimizer checkpoint and the replay memory to disk."""
    checkpoint = {"iter": iteration,
                  "model_state_dict": model.state_dict(),
                  "optimizer": optimizer.state_dict()}
    torch.save(checkpoint, checkpoint_path)
    with open(memory_path, "wb") as f:
        pickle.dump(replay_memory, f, protocol=pickle.HIGHEST_PROTOCOL)


def train(opt):
    """Train the Deep Q-Network on the Chrome Dino environment.

    Training resumes from ``chrome_dino.pth`` / ``replay_memory.pkl`` inside
    ``opt.saved_folder`` when those files exist, and both are rewritten
    every 50000 iterations.

    Args:
        opt: Parsed options as produced by ``get_args`` (batch size, learning
            rate, gamma, epsilon schedule, iteration count, replay memory
            size, output folder and optimizer name).
    """
    # torch.manual_seed seeds the CPU generator and all CUDA devices; seeding
    # only CUDA left the CPU-side weight initialisation unseeded.
    torch.manual_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DeepQNetwork().to(device)

    # Honour --optimizer; Adam remains the default.
    if getattr(opt, "optimizer", "adam") == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    os.makedirs(opt.saved_folder, exist_ok=True)
    checkpoint_path = os.path.join(opt.saved_folder, "chrome_dino.pth")
    memory_path = os.path.join(opt.saved_folder, "replay_memory.pkl")

    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # "iter" stores the number of completed iterations, so training
        # continues from exactly that count.
        iteration = checkpoint["iter"]
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Load trained model from iteration {}".format(iteration))
    else:
        iteration = 0

    if os.path.isfile(memory_path):
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only resume from replay files written by this script.
        with open(memory_path, "rb") as f:
            replay_memory = pickle.load(f)
        print("Load replay memory")
    else:
        replay_memory = []

    criterion = nn.MSELoss()
    env = ChromeDino()
    state, _, _ = env.step(0)
    # The network takes four stacked frames; start with four copies of the first.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :]

    while iteration < opt.num_iters:
        # Action selection needs no gradients.
        with torch.no_grad():
            prediction = model(state.to(device))[0]

        # Exploration or exploitation: epsilon decays linearly from
        # initial_epsilon to final_epsilon over num_decay_iters iterations.
        epsilon = opt.final_epsilon + (
                max(opt.num_decay_iters - iteration, 0) * (opt.initial_epsilon - opt.final_epsilon)
                / opt.num_decay_iters)
        if random() <= epsilon:
            action = randint(0, 2)
        else:
            action = torch.argmax(prediction).item()

        next_state, reward, done = env.step(action)
        # Drop the oldest frame and append the newest one.
        next_state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, done])
        if len(replay_memory) > opt.replay_memory_size:
            del replay_memory[0]

        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)

        state_batch = torch.cat(state_batch).to(device)
        next_state_batch = torch.cat(next_state_batch).to(device)
        action_batch = torch.tensor(action_batch, dtype=torch.int64, device=device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=device)
        done_batch = torch.tensor(done_batch, dtype=torch.bool, device=device)

        # Q-value of the action that was actually taken in each transition.
        q_value = model(state_batch).gather(1, action_batch).squeeze(1)

        # The bootstrap target is a constant for the update: computing it
        # without gradients keeps the loss from also pushing the next-state
        # estimates towards the current prediction.
        with torch.no_grad():
            next_q_max = model(next_state_batch).max(dim=1).values
        y_batch = torch.where(done_batch, reward_batch, reward_batch + opt.gamma * next_q_max)

        optimizer.zero_grad()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state
        iteration += 1
        print("Iteration: {}/{}, Loss: {:.5f}, Epsilon {:.5f}, Reward: {}".format(
            iteration,
            opt.num_iters,
            loss.item(),
            epsilon, reward))
        # Saving on exact multiples also captures the final iteration when
        # num_iters is a multiple of 50000.
        if iteration % 50000 == 0:
            _save_training_state(checkpoint_path, memory_path, iteration, model, optimizer, replay_memory)


# Script entry point: parse the command line, then start (or resume) training.
if __name__ == "__main__":
    opt = get_args()
    train(opt)
相关推荐
饭饭大王666几秒前
CANN 生态深度整合:使用 `pipeline-runner` 构建高吞吐视频分析流水线
人工智能·音视频
初恋叫萱萱1 分钟前
CANN 生态中的异构调度中枢:深入 `runtime` 项目实现高效任务编排
人工智能
简佐义的博客2 分钟前
生信入门进阶指南:学习顶级实验室多组学整合方案,构建肾脏细胞空间分子图谱
人工智能·学习
白日做梦Q2 分钟前
Anchor-free检测器全解析:CenterNet vs FCOS
python·深度学习·神经网络·目标检测·机器学习
吃杠碰小鸡3 分钟前
高中数学-数列-导数证明
前端·数学·算法
无名修道院3 分钟前
自学AI制作小游戏
人工智能·lora·ai大模型应用开发·小游戏制作
kingwebo'sZone8 分钟前
C#使用Aspose.Words把 word转成图片
前端·c#·word
晚霞的不甘12 分钟前
CANN × ROS 2:为智能机器人打造实时 AI 推理底座
人工智能·神经网络·架构·机器人·开源
互联网Ai好者15 分钟前
MiyoAI数参首发体验——不止于监控,更是你的智能决策参谋
人工智能
island131415 分钟前
CANN HIXL 通信库深度解析:单边点对点数据传输、异步模型与异构设备间显存直接访问
人工智能·深度学习·神经网络