We will now dive into the Deep Q-Network (DQN) and train an agent to play GridWorld, a simple text-based game. The game is played on a 4 x 4 grid of tiles on which four objects are placed: an agent (the player), a pit, a goal, and a wall.
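To make the setup concrete, here is a minimal sketch of how such a 4 x 4 grid could be represented and randomly initialized in Java. The class, field, and symbol choices below are illustrative assumptions, not the project's actual code:

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Minimal, illustrative sketch of a 4 x 4 GridWorld state.
 * Names and layout are assumptions, not the project's actual API.
 */
public class GridWorldSketch {
    static final int SIZE = 4;
    // 'A' = agent, 'P' = pit, 'G' = goal, 'W' = wall, '.' = empty tile
    final char[][] grid = new char[SIZE][SIZE];

    GridWorldSketch(long seed) {
        for (char[] row : grid) Arrays.fill(row, '.');
        Random rnd = new Random(seed);
        // Place the four objects on distinct random tiles
        for (char obj : new char[] {'A', 'P', 'G', 'W'}) {
            int r, c;
            do { r = rnd.nextInt(SIZE); c = rnd.nextInt(SIZE); }
            while (grid[r][c] != '.');
            grid[r][c] = obj;
        }
    }

    void print() {
        for (char[] row : grid) System.out.println(new String(row));
    }

    public static void main(String[] args) {
        new GridWorldSketch(42).print();
    }
}
```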
GridWorld project structure
The project has the following structure:
- DeepQNetwork.java: Provides the reference architecture for the DQN
- Replay.java: Implements the replay memory for the DQN, which helps keep the gradients of the deep network stable so that they do not diverge across episodes (a sketch follows this list)
- GridWorld.java: The main class used for training the DQN and playing the game
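As referenced in the Replay.java item above, here is a rough sketch of what an experience-replay buffer does: it stores past (state, action, reward, next state) transitions and serves random mini-batches, so that training does not repeatedly see consecutive, highly correlated steps. The class and method names are assumptions for illustration; the actual Replay.java may differ:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Illustrative experience-replay buffer; the project's Replay.java may differ.
 * Random sampling breaks up the correlation between consecutive steps,
 * which is what stabilizes the DQN's gradient updates.
 */
public class ReplayBufferSketch {
    public static class Transition {
        final float[] state, nextState;
        final int action;
        final float reward;
        final boolean done;
        Transition(float[] s, int a, float r, float[] s2, boolean done) {
            this.state = s; this.action = a; this.reward = r;
            this.nextState = s2; this.done = done;
        }
    }

    private final List<Transition> memory = new ArrayList<>();
    private final int capacity;
    private final Random rnd = new Random();
    private int insertPos = 0;

    public ReplayBufferSketch(int capacity) { this.capacity = capacity; }

    /** Add a transition, overwriting the oldest one once capacity is reached. */
    public void add(Transition t) {
        if (memory.size() < capacity) {
            memory.add(t);
        } else {
            memory.set(insertPos, t);
            insertPos = (insertPos + 1) % capacity;
        }
    }

    /** Sample a random mini-batch (with replacement, for brevity). */
    public List<Transition> sample(int batchSize) {
        List<Transition> batch = new ArrayList<>(batchSize);
        for (int i = 0; i < batchSize; i++) {
            batch.add(memory.get(rnd.nextInt(memory.size())));
        }
        return batch;
    }
}
```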
Note that we perform the training on a GPU with cuDNN for faster convergence. However, feel free to use the CPU backend instead if your machine does not have a GPU.
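For reference, in DL4J-style projects the backend is typically selected through Maven dependencies rather than in code. The snippet below is only a sketch: the exact artifact IDs and versions depend on the DL4J release you use, so verify them against your project's POM and the DL4J documentation before relying on them:

```xml
<!-- Sketch only: artifact IDs and versions vary by DL4J release. -->
<!-- GPU (CUDA/cuDNN) backend -->
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-11.6-platform</artifactId>
    <version>${dl4j.version}</version>
</dependency>

<!-- CPU fallback: swap in this backend if no GPU is available -->
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>${dl4j.version}</version>
</dependency>
```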