"""
file:
infer.py
description:
loads model checkpoints and generates atari enduro gameplay video.
also prints test average of any number of episodes required
url:
https://kyscg.github.io/2025/07/11/dqnenduro
author:
kyscg
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import imageio
import numpy as np
from PIL import Image
from collections import deque
import ale_py
import gymnasium as gym
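# newer ale-py / gymnasium releases expect explicit registration of the ALE
# environments; older releases register them as a side effect of importing ale_py
# (assumption: either setup may be present, hence the guard).
if hasattr(gym, "register_envs"):
    gym.register_envs(ale_py)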
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
MODEL_PATH = "path to checkpoint"
NUM_EVAL_EPISODES = 7
VIDEO_FILENAME = "path to video"
VIDEO_FPS = 30
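# note: ALE/*-v5 environments already apply an internal frame skip (4 by default),
# so the manual FRAME_SKIP loop below repeats each chosen action on top of it;
# keep these values consistent with how the checkpoint was trained.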
FRAME_SKIP = 4
STACK_SIZE = 4
def preprocess_frame(frame):
    """Convert a raw RGB frame to an 84x84 grayscale uint8 array."""
    img = Image.fromarray(frame)
    img = img.convert('L')
    img = img.resize((84, 84), Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.ANTIALIAS)
    processed_frame = np.array(img, dtype=np.uint8)
    return processed_frame
class DQN(nn.Module):
    """Convolutional Q-network over a stack of 84x84 grayscale frames."""

    def __init__(self, num_actions):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(STACK_SIZE, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self._feature_size = self._get_conv_output_size((STACK_SIZE, 84, 84))
        self.fc1 = nn.Linear(self._feature_size, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def _get_conv_output_size(self, shape):
        # pass a dummy input through the conv stack to size the first fully connected layer
        dummy_input = torch.zeros(1, *shape)
        o = F.relu(self.conv1(dummy_input))
        o = F.relu(self.conv2(o))
        o = F.relu(self.conv3(o))
        return int(np.prod(o.size()))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
def generate_gameplay_video():
    env = gym.make('ALE/Enduro-v5', render_mode='rgb_array')
    num_actions = env.action_space.n
    print(f"Environment: ALE/Enduro-v5 | Action Space: {num_actions} actions.")
    policy_net = DQN(num_actions).to(device)
    try:
        policy_net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        policy_net.eval()
        print(f"Successfully loaded model from {MODEL_PATH}")
    except FileNotFoundError:
        print(f"Error: Model file not found at {MODEL_PATH}.")
        env.close()
        return
    except Exception as e:
        print(f"Error loading model: {e}")
        env.close()
        return
    frames_to_record = []
    total_rewards = []
    print(f"Generating gameplay video for {NUM_EVAL_EPISODES} episodes...")
    for episode in range(NUM_EVAL_EPISODES):
        state_raw, _ = env.reset()
        rendered_frame = env.render()
        if rendered_frame is not None:
            frames_to_record.append(rendered_frame)
        # initialise the frame stack by repeating the first processed frame
        state_processed = preprocess_frame(state_raw)
        stacked_frames_deque = deque([state_processed] * STACK_SIZE, maxlen=STACK_SIZE)
        current_stacked_state = np.array(stacked_frames_deque, dtype=np.uint8)
        episode_reward = 0
        done = False
        while not done:
            # greedy action from the policy network (no exploration at eval time)
            state_tensor = torch.tensor(current_stacked_state, dtype=torch.float32).unsqueeze(0).to(device) / 255.0
            with torch.no_grad():
                q_values = policy_net(state_tensor)
            action = q_values.argmax(1).item()
            reward_sum_over_skip = 0
            frame_skip_done = False
            # repeat the chosen action for FRAME_SKIP environment steps
            for _ in range(FRAME_SKIP):
                next_frame_raw, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                rendered_frame = env.render()
                if rendered_frame is not None:
                    frames_to_record.append(rendered_frame)
                reward_sum_over_skip += reward
                if done:
                    frame_skip_done = True
                    break
            # push the latest processed frame onto the stack
            next_frame_processed = preprocess_frame(next_frame_raw)
            stacked_frames_deque.append(next_frame_processed)
            next_stacked_state = np.array(stacked_frames_deque, dtype=np.uint8)
            current_stacked_state = next_stacked_state
            episode_reward += reward_sum_over_skip
            if frame_skip_done:
                done = True
        total_rewards.append(episode_reward)
        print(f"Episode {episode + 1}/{NUM_EVAL_EPISODES} finished with reward: {episode_reward:.2f}")
    env.close()
    if frames_to_record:
        print(f"Saving video to {VIDEO_FILENAME} with {len(frames_to_record)} frames...")
        imageio.mimwrite(VIDEO_FILENAME, frames_to_record, fps=VIDEO_FPS, codec='libx264', quality=9)
        print(f"Total rewards for {NUM_EVAL_EPISODES} episodes: {total_rewards}")
        print(f"Average reward per episode: {np.mean(total_rewards):.2f}")
    else:
        print("No frames were recorded. Video not generated.")


if __name__ == '__main__':
    generate_gameplay_video()