Training an Unbeatable Tic-Tac-Toe AI using Reinforcement Learning
Introduction
The Goal
In this article, we will train a neural network to play the classic game of tic-tac-toe. The network will be trained using reinforcement learning, one of the main machine learning paradigms. Our code will be written in Python and leverage the Ray RLlib library for its rich feature set and industry-grade performance. By the end, we will have created an AI model that never loses and runs in real time on any platform supporting the Open Neural Network Exchange (ONNX) file format.
This article assumes you are comfortable with Python and have a decent handle on machine learning concepts. There are plenty of great resources for learning about machine learning, neural networks, gradient descent, and so on; if I reference something you don't understand, a quick Google search should help you out. The same goes for getting Python up and running. Optionally, you can check out my article to get started.
Our Approach
To create our unbeatable tic-tac-toe AI, we will need to:
- Design an environment in which two agents can play tic-tac-toe by taking actions and receiving reward
- Write scaffolding code to manage training, hyperparameter optimization, and checkpoint handling
- Identify the best model and export to a transferable format (ONNX)
- Evaluate model performance to ensure it does not lose
Our environment will be built using PettingZoo, a library for representing generalized multi-agent RL problems. The model itself is a neural network implemented in PyTorch. However, we won't be calling PyTorch APIs directly to train the model. Instead, we will use the Ray RLlib library to train it for us, converting network outputs into actions and using the subsequent rewards to update the weights.
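To make this concrete, below is a minimal sketch of what a turn-based tic-tac-toe environment can look like in PettingZoo's AEC (agent-environment cycle) API, which models games where agents act one at a time. The class name, observation encoding, and reward scheme (+1 for a win, -1 for a loss or illegal move, 0 for a draw) are my illustrative assumptions, not the exact implementation we will build:

```python
import numpy as np
from gymnasium import spaces
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector

# The eight three-in-a-row lines on a flattened 3x3 board.
WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

class TicTacToeEnv(AECEnv):
    metadata = {"name": "tic_tac_toe_v0"}

    def __init__(self):
        super().__init__()
        self.possible_agents = ["player_0", "player_1"]

    def observation_space(self, agent):
        # Flattened board: 0 = empty, 1 = player_0's mark, 2 = player_1's.
        return spaces.Box(low=0, high=2, shape=(9,), dtype=np.int8)

    def action_space(self, agent):
        return spaces.Discrete(9)  # index of the cell to mark

    def reset(self, seed=None, options=None):
        self.agents = self.possible_agents[:]
        self.board = np.zeros(9, dtype=np.int8)
        self.rewards = {a: 0 for a in self.agents}
        self._cumulative_rewards = {a: 0 for a in self.agents}
        self.terminations = {a: False for a in self.agents}
        self.truncations = {a: False for a in self.agents}
        self.infos = {a: {} for a in self.agents}
        self._selector = agent_selector(self.agents)
        self.agent_selection = self._selector.next()

    def observe(self, agent):
        return self.board.copy()

    def step(self, action):
        agent = self.agent_selection
        if self.terminations[agent] or self.truncations[agent]:
            self._was_dead_step(action)  # required AEC bookkeeping
            return
        mark = self.possible_agents.index(agent) + 1
        opponent = self.agents[1 - self.agents.index(agent)]
        self.rewards = {a: 0 for a in self.agents}
        if self.board[action] != 0:
            # Illegal move: treat it as an immediate loss.
            self.rewards[agent], self.rewards[opponent] = -1, 1
            self.terminations = {a: True for a in self.agents}
        else:
            self.board[action] = mark
            if any(all(self.board[i] == mark for i in line) for line in WIN_LINES):
                self.rewards[agent], self.rewards[opponent] = 1, -1
                self.terminations = {a: True for a in self.agents}
            elif not (self.board == 0).any():
                self.terminations = {a: True for a in self.agents}  # draw
        self._accumulate_rewards()
        self.agent_selection = self._selector.next()
```

In practice we would likely also expose a mask of legal moves so the network never has to learn which cells are occupied, but the skeleton above captures the core structure.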
We will also use other parts of the Ray ecosystem to monitor training and allocate resources (compute time) to the most promising combinations of hyperparameters. Once we have a desirable model, we will export the PyTorch neural network to the ONNX format. Finally, we can test the model to confirm that it never loses.
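As a preview of the export step, converting a PyTorch network to ONNX mainly requires tracing it with a dummy input. The tiny stand-in network and file name below are placeholders for the trained policy (recent RLlib versions also provide an Algorithm.export_policy_model helper that can emit ONNX directly):

```python
import torch
import torch.nn as nn

# Stand-in for the trained policy network (illustrative architecture):
# 9 board cells in, 9 move logits out.
model = nn.Sequential(nn.Linear(9, 64), nn.ReLU(), nn.Linear(64, 9))

dummy_obs = torch.zeros((1, 9), dtype=torch.float32)  # example input for tracing
torch.onnx.export(
    model,
    dummy_obs,
    "tictactoe_policy.onnx",
    input_names=["obs"],
    output_names=["logits"],
    dynamic_axes={"obs": {0: "batch"}},  # allow any batch size at inference
)
```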
Reinforcement Learning Crash Course
Before we dive in, it would be helpful to understand the gist of reinforcement learning. So what is it? Reinforcement learning teaches an agent how to make decisions that maximize the reward signal it receives from the environment. Let's break that down:
- One or more agents live in an environment
- Agents observe things in their environment and use this information to make decisions based on a policy
- The environment rewards agents based on their actions (reward can be negative)
- Agents update their policy based on the reward they receive
- This cycle of observations, actions, rewards, and updates continues until we reach a stopping condition
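In code, this cycle has a familiar shape. Here is a minimal sketch using Gymnasium's single-agent API, with random actions standing in for a learned policy (our actual game is multi-agent, but the loop is the same in spirit):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")  # placeholder environment for illustration
obs, info = env.reset()
done = False
while not done:
    action = env.action_space.sample()  # a real agent would query its policy
    obs, reward, terminated, truncated, info = env.step(action)
    # A learning agent would use (obs, action, reward) here to update its policy.
    done = terminated or truncated
env.close()
```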
The policy of an agent determines how it makes decisions, and is ultimately what we are trying to train. This is the part that is represented by a neural network. As with any good challenge, there are a few difficulties that we must overcome when training an agent with RL:
- The reward for an action may not arrive immediately. The agent must learn how actions taken now affect its chances of receiving reward in the future (the credit assignment problem).
- Consider an agent whose policy is good but could be better. The agent must decide whether to stick with the strategy it knows produces acceptable reward, or to explore changes to that strategy at the risk of missing out on reward it could have received (the exploration-exploitation trade-off).
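One classic remedy for this second difficulty is epsilon-greedy action selection, used by value-based methods such as DQN: with probability epsilon the agent picks a random action (explore), and otherwise it picks the action with the highest estimated value (exploit). A minimal sketch, with made-up value estimates:

```python
import random

def epsilon_greedy(q_values, epsilon=0.1):
    """Return a random action index with probability epsilon, else the greedy one."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))  # explore
    return max(range(len(q_values)), key=q_values.__getitem__)  # exploit

# Example: estimated values for each of the 9 board cells.
print(epsilon_greedy([0.1, 0.5, 0.2, 0.0, 0.9, 0.3, 0.4, 0.6, 0.7]))
```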
Thankfully, machine learning researchers and engineers have developed algorithms and strategies that address these issues. Some popular RL algorithms include Proximal Policy Optimization (PPO), Deep Q Networks (DQN), Soft Actor Critic (SAC), and IMPALA. For more details on RL and commonly used algorithms, take a glance at this paper.
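To give a flavor of what running one of these algorithms looks like, here is a minimal sketch using RLlib's PPOConfig builder (assuming Ray 2.x; the environment and hyperparameters are placeholders, and our registered PettingZoo environment would eventually go in their place):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")     # placeholder; our env would be registered here
    .training(lr=3e-4, gamma=0.99)  # illustrative hyperparameters
)
algo = config.build()
for _ in range(3):
    results = algo.train()  # one training iteration; returns a dict of metrics
algo.stop()
```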