Agents Module

The agents module provides the agent implementations used to play UNO: an abstract base class, a random baseline, a Stable-Baselines3 wrapper, and a DQN agent.

Base Agent Class

class src.agents.RLAgent

Abstract base class for all RL agents.

abstract select_action(obs, valid_actions)

Select an action given observation and valid actions.

Parameters:
  • obs (np.ndarray) – Current observation

  • valid_actions (List[int]) – List of valid action indices

Returns:

Selected action index

Return type:

int

reset()

Reset any per-episode agent state, such as the hidden state of an LSTM-based agent. Call it at the start of each game.
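As a rough illustration of the interface above, a custom agent only has to implement select_action, and reset can stay a no-op for stateless agents. The sketch below defines a stand-in RLAgent instead of importing the real one, so the base-class body is an assumption. FirstValidAgent is a hypothetical example and not part of the module, and a plain list stands in for the np.ndarray observation to keep the sketch dependency-free.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class RLAgent(ABC):
    """Stand-in for src.agents.RLAgent (assumed shape, not the real source)."""

    @abstractmethod
    def select_action(self, obs: Sequence[float], valid_actions: List[int]) -> int:
        """Return one index taken from valid_actions."""

    def reset(self) -> None:
        """No-op by default; stateful agents override this."""


class FirstValidAgent(RLAgent):
    """Hypothetical agent that always plays the lowest valid action index."""

    def select_action(self, obs: Sequence[float], valid_actions: List[int]) -> int:
        if not valid_actions:
            raise ValueError("valid_actions must not be empty")
        return min(valid_actions)


agent = FirstValidAgent()
agent.reset()  # harmless for a stateless agent
chosen = agent.select_action([0.0] * 17, [4, 2, 7])
```

Because select_action is abstract, instantiating the base class directly raises TypeError, which catches subclasses that forget to implement it.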

RandomAgent Class

class src.agents.RandomAgent

Agent that selects random valid actions.

Useful as a baseline for evaluation.

select_action(obs, valid_actions)

Select a random valid action.

Parameters:
  • obs (np.ndarray) – Current observation (ignored)

  • valid_actions (List[int]) – List of valid action indices

Returns:

Random action from valid_actions

Return type:

int
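The behavior described above amounts to a uniform draw from valid_actions. The following is a minimal sketch of that idea, not the module's implementation: SeededRandomAgent and its seed parameter are hypothetical, added here only to show how a baseline can be made repeatable across evaluation runs.

```python
import random
from typing import Optional, Sequence


class SeededRandomAgent:
    """Hypothetical random baseline with an optional seed for repeatable runs."""

    def __init__(self, seed: Optional[int] = None) -> None:
        # A private RNG keeps the agent independent of the global random state.
        self._rng = random.Random(seed)

    def select_action(self, obs, valid_actions: Sequence[int]) -> int:
        # obs is ignored, matching the documented RandomAgent behavior.
        if not valid_actions:
            raise ValueError("valid_actions must not be empty")
        return self._rng.choice(list(valid_actions))

    def reset(self) -> None:
        """Nothing to reset: the agent keeps no per-episode state."""


# Two agents built with the same seed make the same sequence of choices.
a = SeededRandomAgent(seed=0)
b = SeededRandomAgent(seed=0)
picks_a = [a.select_action(None, [1, 3, 5]) for _ in range(10)]
picks_b = [b.select_action(None, [1, 3, 5]) for _ in range(10)]
```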

SB3Agent Class

class src.agents.SB3Agent(model_path, deterministic=True)

Wrapper for Stable-Baselines3 models.

Parameters:
  • model_path (str) – Path to saved model file (.zip)

  • deterministic (bool) – Whether to use deterministic actions

model: BaseAlgorithm

The loaded SB3 model.

lstm_state: Tuple | None

Hidden state for recurrent models.

select_action(obs, valid_actions)

Select action using the trained model.

Parameters:
  • obs (np.ndarray) – Current observation

  • valid_actions (List[int]) – Valid actions (used for masking if needed)

Returns:

Action from model prediction

Return type:

int
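The source does not say how valid_actions is applied when the model proposes an illegal move. One simple possibility, sketched here purely as an assumption, is to keep the prediction when it is legal and otherwise fall back to a random legal action. The helper name mask_action is hypothetical.

```python
import random
from typing import Optional, Sequence


def mask_action(predicted: int, valid_actions: Sequence[int],
                rng: Optional[random.Random] = None) -> int:
    """Return predicted if it is legal, otherwise a random legal action."""
    if not valid_actions:
        raise ValueError("valid_actions must not be empty")
    if predicted in valid_actions:
        return predicted
    rng = rng or random.Random()
    return rng.choice(list(valid_actions))


kept = mask_action(3, [1, 3, 5])                            # legal, returned unchanged
replaced = mask_action(8, [1, 3, 5], rng=random.Random(0))  # illegal, replaced
```

A fallback like this keeps the game loop from crashing on an illegal move, but it hides how often the policy proposes one, so counting the replacements during evaluation is worth considering.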

reset()

Reset LSTM hidden state.

classmethod load(model_path)

Load a model from file.

Parameters:
  • model_path (str) – Path to saved model file (.zip)

Returns:

New SB3Agent instance

Return type:

SB3Agent

DQNAgent Class

class src.agents.DQNAgent(model_path=None, state_size=17, action_size=9)

Deep Q-Network agent.

Parameters:
  • model_path (str, optional) – Path to saved model

  • state_size (int) – Observation dimension (default 17)

  • action_size (int) – Number of actions (default 9)

epsilon: float

Exploration rate for epsilon-greedy.

memory: deque

Experience replay buffer.

select_action(obs, valid_actions)

Select action using epsilon-greedy policy.
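Epsilon-greedy selection restricted to legal moves can be sketched as follows. This is an illustration under assumptions, not the DQNAgent source: the function name epsilon_greedy is hypothetical, and q_values stands in for the network's output, one value per action.

```python
import random
from typing import Sequence


def epsilon_greedy(q_values: Sequence[float], valid_actions: Sequence[int],
                   epsilon: float, rng: random.Random) -> int:
    """Pick a valid action: random with probability epsilon, else highest Q."""
    if not valid_actions:
        raise ValueError("valid_actions must not be empty")
    if rng.random() < epsilon:
        return rng.choice(list(valid_actions))             # explore
    return max(valid_actions, key=lambda a: q_values[a])   # exploit among legal moves


rng = random.Random(0)
# One Q-value per action, matching the documented default action_size=9.
q = [0.1, 0.9, 0.3, 0.8, 0.0, 0.2, 0.4, 0.5, 0.6]
greedy = epsilon_greedy(q, [0, 3, 4], epsilon=0.0, rng=rng)    # always exploits
explored = epsilon_greedy(q, [0, 3, 4], epsilon=1.0, rng=rng)  # always explores
```

Taking the maximum over valid_actions only, not over all of q_values, is what keeps the greedy branch from choosing an illegal action with a high Q-value, such as index 1 here.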

train(batch_size=64)

Train on a batch of batch_size experiences from the replay buffer.
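The first step of a training call like this is drawing a batch from the deque-backed memory. The sketch below assumes a (state, action, reward, next_state, done) tuple layout and uniform sampling without replacement. Neither is confirmed by the source, and sample_batch is a hypothetical helper.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

# Assumed transition layout: (state, action, reward, next_state, done).
Transition = Tuple[List[float], int, float, List[float], bool]


def sample_batch(memory: Deque[Transition], batch_size: int,
                 rng: random.Random) -> List[Transition]:
    """Return batch_size distinct transitions, or [] if the buffer is too small."""
    if len(memory) < batch_size:
        return []
    return rng.sample(list(memory), batch_size)


# maxlen makes the deque drop the oldest transition once it is full.
memory: Deque[Transition] = deque(maxlen=1000)
for step in range(100):
    memory.append(([float(step)], step % 9, 0.0, [float(step + 1)], False))

batch = sample_batch(memory, batch_size=64, rng=random.Random(0))
too_small = sample_batch(memory, batch_size=128, rng=random.Random(0))
```

Returning an empty batch while the buffer is still filling lets the caller skip the update without special-casing the first few steps of training.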

save(path)

Save model weights to path.

load(path)

Load model weights from path.

Recurrent Agents

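For LSTM-based agents, the hidden state persists across timesteps within an episode and is cleared by reset(). A minimal subclass looks like this:

from src.agents import SB3Agent


class RecurrentAgent(SB3Agent):
    """Agent with LSTM memory."""

    def __init__(self, model_path):
        super().__init__(model_path)
        self.lstm_state = None  # no hidden state until the first prediction

    def select_action(self, obs, valid_actions):
        # predict() returns the action and the updated hidden state; passing
        # the state back in is what carries memory from one timestep to the next.
        # valid_actions is not used for masking in this sketch.
        action, self.lstm_state = self.model.predict(
            obs,
            state=self.lstm_state,
            deterministic=True,
        )
        return int(action)  # predict() returns a NumPy array, not a plain int

    def reset(self):
        # Call at the start of each game so memory does not leak across episodes.
        self.lstm_state = None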

Example Usage

Random Agent

from src.agents import RandomAgent
from src.state_action_reward import UnoEnv

env = UnoEnv()
agent = RandomAgent()

obs, _ = env.reset()
done = False

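while not done:
    valid = env.get_valid_actions()
    action = agent.select_action(obs, valid)
    # step() returns separate terminated and truncated flags; stop on either.
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated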

SB3 Agent

from src.agents import SB3Agent
from src.state_action_reward import UnoEnv

env = UnoEnv()
agent = SB3Agent("models/selfplay_champion.zip")

obs, _ = env.reset()
agent.reset()  # Clear LSTM state
done = False

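while not done:
    valid = env.get_valid_actions()
    action = agent.select_action(obs, valid)
    # step() returns separate terminated and truncated flags; stop on either.
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated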

print(f"Result: {'Win' if reward > 0 else 'Loss'}")

Comparing Agents

from src.agents import SB3Agent, RandomAgent
from src.state_action_reward import UnoEnv

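def evaluate(agent1, agent2, env, num_games=100):
    """Return the fraction of num_games that agent1 wins."""
    # NOTE: agent2 is never queried in this loop. The opponent's moves are
    # assumed to be played inside env.step().
    wins = 0
    for _ in range(num_games):
        obs, _ = env.reset()
        agent1.reset()
        done = False

        while not done:
            action = agent1.select_action(obs, env.get_valid_actions())
            # Stop on either termination or truncation.
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

        if reward > 0:  # a positive final reward counts as a win
            wins += 1

    return wins / num_games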

env = UnoEnv()
champion = SB3Agent("models/selfplay_champion.zip")
random_agent = RandomAgent()  # avoid shadowing the stdlib "random" module

win_rate = evaluate(champion, random_agent, env)
print(f"Champion win rate: {win_rate:.1%}")
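A win rate measured over a finite number of games is only an estimate, so it helps to report an interval alongside it when comparing agents. The sketch below computes a Wilson score interval for a win count. The helper wilson_interval is hypothetical and not part of the module, and z=1.96 corresponds to roughly 95% confidence.

```python
import math
from typing import Tuple


def wilson_interval(wins: int, games: int, z: float = 1.96) -> Tuple[float, float]:
    """Return (low, high) bounds of the Wilson score interval for wins/games."""
    if games <= 0:
        raise ValueError("games must be positive")
    if not 0 <= wins <= games:
        raise ValueError("wins must be between 0 and games")
    p = wins / games
    denom = 1 + z * z / games
    center = (p + z * z / (2 * games)) / denom
    half = z * math.sqrt(p * (1 - p) / games + z * z / (4 * games * games)) / denom
    return max(0.0, center - half), min(1.0, center + half)


# Same observed win rate, different sample sizes.
low_100, high_100 = wilson_interval(70, 100)
low_1000, high_1000 = wilson_interval(700, 1000)
```

With 70 wins in 100 games the interval spans roughly 0.60 to 0.78, so differences of a few percentage points between agents are hard to distinguish at the default num_games=100. Raising num_games narrows the interval.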

Agent Comparison Table

Agent                 Type  Win Rate  Notes
--------------------  ----  --------  ----------------
SB3Agent (Self-Play)  LSTM  70%+      Best performance
SB3Agent (RecPPO)     LSTM  60%       Good baseline
SB3Agent (PPO)        MLP   53%       No memory
DQNAgent              MLP   48%       Value-based
RandomAgent           None  25%       Baseline