Agents Module
The agents module provides the agent implementations used to play UNO: a random baseline, a wrapper for trained Stable-Baselines3 models, and a DQN agent.
Base Agent Class
RandomAgent Class
- class src.agents.RandomAgent
Agent that selects random valid actions.
Useful as a baseline for evaluation.
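A random baseline of this kind can be sketched in a few lines. The class below is a hypothetical stand-in, not the project's actual RandomAgent: the name SketchRandomAgent, the seed argument, and the assumption that valid_actions is an iterable of integer action indices are all illustrative.

```python
import random


class SketchRandomAgent:
    """Hypothetical stand-in for RandomAgent; not the project's actual code."""

    def __init__(self, seed=None):
        # A private RNG keeps runs reproducible when a seed is given.
        self._rng = random.Random(seed)

    def select_action(self, obs, valid_actions):
        # The observation is ignored: every valid action is equally likely.
        actions = list(valid_actions)
        if not actions:
            raise ValueError("no valid actions to choose from")
        return self._rng.choice(actions)

    def reset(self):
        # Nothing to clear; kept so the sketch matches the other agents' interface.
        pass
```

Because the choice is restricted to valid_actions, such an agent never plays an illegal move, which is what makes it a fair lower bound for evaluation.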
SB3Agent Class
- class src.agents.SB3Agent(model_path, deterministic=True)
Wrapper for Stable-Baselines3 models.
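- Parameters:
model_path – Path to the saved SB3 model
deterministic – Whether actions are predicted deterministically (default: True)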
- model: BaseAlgorithm
The loaded SB3 model.
- select_action(obs, valid_actions)
Select action using the trained model.
- Parameters:
obs – Current observation
valid_actions – Valid actions (used for masking if needed)
- Returns:
Action from model prediction
- Return type:
- reset()
Reset LSTM hidden state.
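The reference above says valid_actions is "used for masking if needed" but does not say how. One possible approach, shown here purely as an assumption, is to keep the model's prediction when it is legal and otherwise fall back to a legal action. The helper name apply_action_mask and the first-legal-action fallback are hypothetical.

```python
def apply_action_mask(predicted_action, valid_actions):
    """Return predicted_action if it is legal, else fall back to a legal action."""
    action = int(predicted_action)  # SB3 predict() returns a NumPy value
    legal = [int(a) for a in valid_actions]
    if not legal:
        raise ValueError("no valid actions available")
    if action in legal:
        return action
    # Assumed fallback policy: take the first legal action.
    return legal[0]
```

A fallback like this keeps the game running, but it also hides how often the model proposes illegal moves, so counting those cases separately may be worthwhile.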
DQNAgent Class
- class src.agents.DQNAgent(model_path=None, state_size=17, action_size=9)
Deep Q-Network agent.
- Parameters:
model_path – Path to saved model (optional)
state_size – Observation dimension
action_size – Number of actions
- memory: deque
Experience replay buffer.
- select_action(obs, valid_actions)
Select action using epsilon-greedy policy.
- train(batch_size=64)
Train on a batch from replay buffer.
- save(path)
Save model weights.
- load(path)
Load model weights.
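The two ideas behind select_action and memory can be illustrated without a neural network. The sketch below assumes Q-values arrive as a plain sequence indexed by action and that valid_actions holds integer indices. The function names, the buffer size of 10,000, and the transition tuple layout are illustrative choices, not the DQNAgent internals.

```python
import random
from collections import deque


def epsilon_greedy(q_values, valid_actions, epsilon, rng=random):
    """Pick a random legal action with probability epsilon, else the best legal one."""
    legal = list(valid_actions)
    if not legal:
        raise ValueError("no valid actions available")
    if rng.random() < epsilon:
        return rng.choice(legal)  # explore
    # Exploit: argmax restricted to legal actions only.
    return max(legal, key=lambda a: q_values[a])


# Bounded replay buffer: the oldest transitions are dropped once maxlen is reached.
memory = deque(maxlen=10_000)


def remember(obs, action, reward, next_obs, done):
    memory.append((obs, action, reward, next_obs, done))


def sample_batch(batch_size=64, rng=random):
    # Sample without replacement; return fewer items if the buffer is still small.
    return rng.sample(list(memory), min(batch_size, len(memory)))
```

Restricting the argmax to legal actions means an illegal action with a high Q-value is never chosen, even when epsilon is zero.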
Recurrent Agents
For LSTM-based agents, the hidden state persists across timesteps:
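class RecurrentAgent(SB3Agent):
    """Agent with LSTM memory."""

    def __init__(self, model_path):
        super().__init__(model_path)
        # No hidden state yet; predict() initializes it on the first call.
        self.lstm_state = None

    def select_action(self, obs, valid_actions):
        # Pass the previous hidden state in and store the updated one.
        action, self.lstm_state = self.model.predict(
            obs,
            state=self.lstm_state,
            deterministic=True
        )
        return action

    def reset(self):
        # Call at the start of each game so state does not carry over.
        self.lstm_state = None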
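The state-threading pattern can be exercised without a trained model. In the sketch below, StubRecurrentModel is a hypothetical stand-in whose "hidden state" is just a step counter. It only mimics the (action, state) return shape of predict() and is not part of Stable-Baselines3.

```python
class StubRecurrentModel:
    """Hypothetical stand-in for a recurrent model; its state is a step counter."""

    def predict(self, obs, state=None, deterministic=True):
        # state is None on the first call of an episode, so start from zero.
        steps_seen = 0 if state is None else state
        steps_seen += 1
        action = steps_seen % 9  # dummy action derived from the carried state
        return action, steps_seen


class StatefulAgent:
    def __init__(self, model):
        self.model = model
        self.lstm_state = None

    def select_action(self, obs, valid_actions):
        # Feed the previous state back in and keep the new one for the next step.
        action, self.lstm_state = self.model.predict(
            obs, state=self.lstm_state, deterministic=True
        )
        return action

    def reset(self):
        # Without this, state from the previous game leaks into the next one.
        self.lstm_state = None


agent = StatefulAgent(StubRecurrentModel())
for _ in range(3):
    agent.select_action(obs=None, valid_actions=[])
state_after_three_steps = agent.lstm_state  # carried across the three calls
agent.reset()
state_after_reset = agent.lstm_state  # cleared for the next game
```

The counter makes the carry-over visible: each call sees the state left by the previous one until reset() clears it.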
Example Usage
Random Agent
from src.agents import RandomAgent
from src.state_action_reward import UnoEnv

env = UnoEnv()
agent = RandomAgent()

obs, _ = env.reset()
done = False
while not done:
    valid = env.get_valid_actions()
    action = agent.select_action(obs, valid)
    # Assuming the Gymnasium-style 5-tuple: stop on termination or truncation.
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
SB3 Agent
from src.agents import SB3Agent
from src.state_action_reward import UnoEnv

env = UnoEnv()
agent = SB3Agent("models/selfplay_champion.zip")

obs, _ = env.reset()
agent.reset()  # Clear LSTM state
done = False
while not done:
    valid = env.get_valid_actions()
    action = agent.select_action(obs, valid)
    # Assuming the Gymnasium-style 5-tuple: stop on termination or truncation.
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

# The final reward decides the outcome: positive means a win.
print(f"Result: {'Win' if reward > 0 else 'Loss'}")
Comparing Agents
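from src.agents import SB3Agent, RandomAgent
from src.state_action_reward import UnoEnv

def evaluate(agent1, agent2, env, num_games=100):
    """Return agent1's win rate over num_games games."""
    # Note: agent2 is never consulted in this loop; only agent1 acts through env.step().
    wins = 0
    for _ in range(num_games):
        obs, _ = env.reset()
        agent1.reset()
        done = False
        while not done:
            action = agent1.select_action(obs, env.get_valid_actions())
            # Assuming the Gymnasium-style 5-tuple: stop on termination or truncation.
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
        if reward > 0:  # a positive final reward is counted as a win
            wins += 1
    return wins / num_games

env = UnoEnv()
champion = SB3Agent("models/selfplay_champion.zip")
random_agent = RandomAgent()  # named to avoid clashing with the stdlib random module
win_rate = evaluate(champion, random_agent, env)
print(f"Champion win rate: {win_rate:.1%}")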
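A win rate measured over 100 games carries noticeable sampling noise, so it can help to report an interval alongside the point estimate. The helper below is a minimal sketch using the normal approximation to the binomial. The name win_rate_interval and the default z=1.96 (roughly a 95% interval) are illustrative and not part of the module.

```python
import math


def win_rate_interval(wins, num_games, z=1.96):
    """Normal-approximation confidence interval for a win rate."""
    if num_games <= 0:
        raise ValueError("num_games must be positive")
    p = wins / num_games
    # Standard error of a binomial proportion, scaled by the z-score.
    half_width = z * math.sqrt(p * (1 - p) / num_games)
    # Clamp to [0, 1] since a win rate cannot leave that range.
    return max(0.0, p - half_width), min(1.0, p + half_width)
```

For 70 wins in 100 games the half-width is about 9 percentage points, so differences of a few points between agents may need more games to confirm. The approximation is less reliable when the win rate is very close to 0 or 1.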
Agent Comparison Table
| Agent | Type | Win Rate | Notes |
|---|---|---|---|
| SB3Agent (Self-Play) | LSTM | 70%+ | Best performance |
| SB3Agent (RecPPO) | LSTM | 60% | Good baseline |
| SB3Agent (PPO) | MLP | 53% | No memory |
| DQNAgent | MLP | 48% | Value-based |
| RandomAgent | None | 25% | Baseline |