"""
file:
train.py
description:
implementation of a deep q-network that teaches an agent to play enduro
from the atari gym. prints logs to console and saves model checkpoints.
url:
https://kyscg.github.io/2025/07/11/dqnenduro
author:
kyscg
"""
import random
from collections import deque

import ale_py  # imported so that Gymnasium can resolve the ALE/* environment IDs
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
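
# Hyperparameters. The values roughly follow the DQN setup of Mnih et al. (2015),
# with a smaller replay buffer, an exponential (rather than linear) epsilon decay,
# and a much shorter run so the script finishes in reasonable time.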
GAMMA = 0.99                   # discount factor
BATCH_SIZE = 32                # minibatch size sampled from the replay buffer
BUFFER_CAPACITY = 100000       # maximum number of transitions kept in replay memory
LEARNING_RATE = 0.00025        # RMSprop learning rate
EPS_START = 1.0                # initial epsilon for epsilon-greedy exploration
EPS_END = 0.1                  # final epsilon
EPS_DECAY = 250000             # decay constant (in steps) of the exponential epsilon schedule
TARGET_UPDATE = 10000          # how often (in steps) to copy policy-net weights into the target net
NUM_EPISODES = 500             # number of training episodes
INITIAL_COLLECT_STEPS = 50000  # random steps used to pre-fill the replay buffer
FRAME_SKIP = 4                 # number of times each selected action is repeated
STACK_SIZE = 4                 # number of consecutive frames stacked into one state
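
# Each raw RGB frame from the emulator is converted to grayscale and resized to
# 84x84 (the input resolution used by the original DQN); the agent's state is a
# stack of STACK_SIZE consecutive preprocessed frames.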
def preprocess_frame(frame):
    """Convert a raw RGB frame to an 84x84 grayscale uint8 array."""
    img = Image.fromarray(frame)
    img = img.convert('L')
    # Image.ANTIALIAS was removed in Pillow 10; prefer the Resampling enum when available.
    resample = Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.ANTIALIAS
    img = img.resize((84, 84), resample)
    processed_frame = np.array(img, dtype=np.uint8)
    return processed_frame
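
# The Q-network follows the convolutional architecture from the Nature DQN paper
# (Mnih et al., 2015): three conv layers followed by a 512-unit fully connected
# layer and a linear output head with one Q-value per action.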
class DQN(nn.Module):
    def __init__(self, num_actions):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(STACK_SIZE, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # Infer the flattened feature size by running a dummy input through the conv stack.
        self._feature_size = self._get_conv_output_size((STACK_SIZE, 84, 84))
        self.fc1 = nn.Linear(self._feature_size, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def _get_conv_output_size(self, shape):
        dummy = torch.zeros(1, *shape)
        o = F.relu(self.conv1(dummy))
        o = F.relu(self.conv2(o))
        o = F.relu(self.conv3(o))
        return int(np.prod(o.size()))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
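
# A simple FIFO experience replay buffer. Transitions are stored as uint8 frame
# stacks and only converted to normalized float tensors when a batch is sampled,
# which keeps memory usage manageable.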
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        states_raw, actions, rewards, next_states_raw, dones = zip(*transitions)
        # Scale pixel values from [0, 255] to [0, 1] when building the batch.
        states = torch.tensor(np.array(states_raw), dtype=torch.float32).to(device) / 255.0
        next_states = torch.tensor(np.array(next_states_raw), dtype=torch.float32).to(device) / 255.0
        actions = torch.tensor(np.array(actions), dtype=torch.long).to(device)
        rewards = torch.tensor(np.array(rewards), dtype=torch.float32).to(device)
        dones = torch.tensor(np.array(dones), dtype=torch.bool).to(device)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)
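
# The agent keeps two copies of the network: the policy net, which is optimized
# every step, and a target net, which is only synced every TARGET_UPDATE steps
# and is used to compute the bootstrapped TD targets.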
class DQNAgent:
    def __init__(self, state_shape, action_dim, lr, gamma, eps_start, eps_end, eps_decay, target_update_freq, device):
        self.action_dim = action_dim
        self.gamma = gamma
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.target_update_freq = target_update_freq
        self.device = device
        self.policy_net = DQN(action_dim).to(device)
        self.target_net = DQN(action_dim).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=lr, alpha=0.95, eps=0.01, weight_decay=0)
        self.criterion = nn.SmoothL1Loss()  # Huber loss
        self.steps_done = 0

    def select_action(self, state):
        # Epsilon decays exponentially from eps_start towards eps_end as steps_done grows.
        eps_threshold = self.eps_end + (self.eps_start - self.eps_end) * np.exp(-1. * self.steps_done / self.eps_decay)
        if random.random() < eps_threshold:
            return torch.tensor([random.randrange(self.action_dim)], device=self.device, dtype=torch.long)
        else:
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device) / 255.0
                q_values = self.policy_net(state_tensor)
                return q_values.argmax(dim=1)

    def optimize_model(self, replay_buffer):
        if len(replay_buffer) < BATCH_SIZE:
            return
        states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)
        # Q(s, a) for the actions that were actually taken.
        state_action_values = self.policy_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
        with torch.no_grad():
            # max_a' Q_target(s', a'), zeroed out below for terminal transitions.
            next_state_values = self.target_net(next_states).max(1)[0]
        expected_state_action_values = rewards + (self.gamma * next_state_values * (1 - dones.float()))
        loss = self.criterion(state_action_values, expected_state_action_values)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()
        self.steps_done += 1
        if self.steps_done % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
            print(f'Target network updated at step {self.steps_done}')
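
# Training proceeds in two phases: first the replay buffer is pre-filled with
# transitions from a uniformly random policy, then the agent trains with an
# epsilon-greedy policy, repeating each chosen action for FRAME_SKIP steps and
# clipping the summed reward to {-1, 0, +1} as in the original DQN.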
def train():
    # Note: the ALE/*-v5 environments apply a frameskip of 4 by default;
    # frameskip=1 lets the manual FRAME_SKIP loop below control action repetition.
    env = gym.make('ALE/Enduro-v5', frameskip=1)
    num_actions = env.action_space.n
    print(f'Environment: ALE/Enduro-v5 | Action Space: {num_actions} actions.')
    agent = DQNAgent(
        state_shape=(STACK_SIZE, 84, 84),
        action_dim=num_actions,
        lr=LEARNING_RATE,
        gamma=GAMMA,
        eps_start=EPS_START,
        eps_end=EPS_END,
        eps_decay=EPS_DECAY,
        target_update_freq=TARGET_UPDATE,
        device=device
    )
    replay_buffer = ReplayBuffer(BUFFER_CAPACITY)
    total_frames = 0
    episode_rewards = []

    print("--- Starting Initial Experience Collection (Random Actions) ---")
    print(f"Collecting {INITIAL_COLLECT_STEPS} steps to fill replay buffer...")
    state_raw, _ = env.reset()
    state_processed = preprocess_frame(state_raw)
    stacked_frames_deque = deque([state_processed] * STACK_SIZE, maxlen=STACK_SIZE)
    current_stacked_state = np.array(stacked_frames_deque, dtype=np.uint8)
    for i in range(INITIAL_COLLECT_STEPS):
        action = env.action_space.sample()
        next_frame_raw, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_frame_preprocessed = preprocess_frame(next_frame_raw)
        stacked_frames_deque.append(next_frame_preprocessed)
        next_stacked_state = np.array(stacked_frames_deque, dtype=np.uint8)
        clipped_reward = np.sign(reward)  # clip rewards to {-1, 0, +1}
        replay_buffer.push(current_stacked_state, action, clipped_reward, next_stacked_state, done)
        current_stacked_state = next_stacked_state
        total_frames += 1
        if done:
            state_raw, _ = env.reset()
            state_processed = preprocess_frame(state_raw)
            stacked_frames_deque = deque([state_processed] * STACK_SIZE, maxlen=STACK_SIZE)
            current_stacked_state = np.array(stacked_frames_deque, dtype=np.uint8)
        if (i + 1) % 10000 == 0:
            print(f"Collected {i + 1}/{INITIAL_COLLECT_STEPS} frames for initial buffer.")
    print(f'Initial experience collection finished. Replay buffer size: {len(replay_buffer)}')
    print('--- Starting Main Training Loop ---')
    for episode in range(NUM_EPISODES):
        state_raw, _ = env.reset()
        state_processed = preprocess_frame(state_raw)
        stacked_frames_deque = deque([state_processed] * STACK_SIZE, maxlen=STACK_SIZE)
        current_stacked_state = np.array(stacked_frames_deque, dtype=np.uint8)
        episode_reward = 0
        done = False
        while not done:
            action_tensor = agent.select_action(current_stacked_state)
            action = action_tensor.item()
            # Repeat the chosen action for FRAME_SKIP frames and accumulate the reward.
            reward_sum_over_skip = 0
            frame_skip_done = False
            for _ in range(FRAME_SKIP):
                next_frame_raw, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                reward_sum_over_skip += reward
                if done:
                    frame_skip_done = True
                    break
            next_frame_processed = preprocess_frame(next_frame_raw)
            stacked_frames_deque.append(next_frame_processed)
            next_stacked_state = np.array(stacked_frames_deque, dtype=np.uint8)
            clipped_reward_sum = np.sign(reward_sum_over_skip)  # clip rewards to {-1, 0, +1}
            replay_buffer.push(current_stacked_state, action, clipped_reward_sum, next_stacked_state, done)
            current_stacked_state = next_stacked_state
            episode_reward += reward_sum_over_skip
            total_frames += 1
            # Keep the agent's step counter in sync with the global step count so that
            # epsilon decay and target-network updates track total environment steps.
            agent.steps_done = total_frames
            agent.optimize_model(replay_buffer)
            if frame_skip_done:
                done = True
        episode_rewards.append(episode_reward)
        avg_reward_last_100 = np.mean(episode_rewards[-100:])
        current_epsilon = agent.eps_end + (agent.eps_start - agent.eps_end) * np.exp(-1. * agent.steps_done / agent.eps_decay)
        print(f"Episode {episode + 1}/{NUM_EPISODES} | Total Frames: {total_frames} | "
              f"Epsilon: {current_epsilon:.4f} | Reward: {episode_reward:.2f} | "
              f"Avg 100-episode Reward: {avg_reward_last_100:.2f}")
        if (episode + 1) % 50 == 0:
            torch.save(agent.policy_net.state_dict(), f"dqn_enduro_episode_{episode + 1}.pth")
            print(f"Model checkpoint saved after episode {episode + 1}")
    env.close()

    print("\n--- Training Finished! ---")
    print("Final average reward over last 100 episodes:", np.mean(episode_rewards[-100:]))
    torch.save(agent.policy_net.state_dict(), "dqn_enduro_final.pth")
    print("Final model saved as dqn_enduro_final.pth")


if __name__ == '__main__':
    train()