🎉 first commit
commit 9823a78177
5 changed files with 476 additions and 0 deletions

README.md (new file, 76 lines)

# TicTacToe

This repository is a simple implementation of the game of TicTacToe and some
experiments around Reinforcement Learning with it.

## Structure

* `game.py` contains the implementation of the game itself.
* `generate_board_hash_list.py` creates a pickle object containing a list of
  hashes, one for every possible unique non-ending variation of the board. It
  is used to create the Q-table later and needs to be precomputed.
* `q_learning.py` contains some experimentation with Q-learning using the
  TicTacToe game as an example.

## Implementation details

The TicTacToe game is a Python class. The board is a 3x3 ndarray (NumPy) of
dtype `int`. Input is taken from 1 to 9 following this scheme:

```
+---+---+---+
| 1 | 2 | 3 |
+---+---+---+
| 4 | 5 | 6 |
+---+---+---+
| 7 | 8 | 9 |
+---+---+---+
```

It is automatically raveled/unraveled when necessary.
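
To illustrate the mapping between the 1-9 input scheme and the 3x3 board, here is a small sketch (not part of the repository, variable names are illustrative) using the same NumPy helpers the game relies on:

```python
import numpy as np

size = (3, 3)
board = np.zeros(size, dtype=int)

move = 6                                       # the player types 6 (middle row, right column)
row, col = np.unravel_index(move - 1, size)    # 0-based flat index 5 -> row 1, col 2
board[row, col] = 1                            # mark that cell for player 1

flat = np.ravel_multi_index((row, col), size)  # back to the 0-based flat index, i.e. 5
print(row, col, flat + 1)                      # 1 2 6
```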

We only need to check for a win after the fifth move, because it is impossible
to have a winner before that. At 9 moves the board is full and the game is
considered a draw if no one has won.

## Combinatorics

Without taking anything into account, we can estimate an upper bound on the
number of possible boards: there are $3^9 = 19683$ possibilities.

There are 8 different symmetries (the dihedral group of order 8, a.k.a. the
symmetry group of the square). This drastically reduces the number of possible
boards.

Taking into account the symmetries and the impossible boards (more O than X,
for example), we get $765$ boards.

Since we do not need to store the last board in the DAG, this number drops to
$627$ non-ending boards.

This makes our state space size $627$ and our action space size $9$.
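
The scripts exploit these symmetries by hashing all 8 transforms of a board and keeping a single hash per equivalence class. A minimal sketch of the idea (the `canonical_hash` helper and the choice of the smallest digest as representative are illustrative, not the exact bookkeeping of the scripts):

```python
import numpy as np
from hashlib import sha1

def symmetries(board):
    """Return the 8 images of `board` under the dihedral group of the square."""
    r1 = np.rot90(board, 1)
    return [board, r1, np.rot90(board, 2), np.rot90(board, 3),
            np.fliplr(board), np.flipud(board), board.T, np.fliplr(r1)]

def canonical_hash(board):
    """Hash every symmetry and keep the smallest digest as the class representative."""
    return min(sha1(np.ascontiguousarray(b)).hexdigest() for b in symmetries(board))

b = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 0]])
assert canonical_hash(b) == canonical_hash(np.rot90(b))  # same equivalence class, same hash
```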

## Reward

* `+1` for the winning side
* `-1` for the losing side
* `±0` in case of a draw

The rewards are given only at the end of an episode, when the winner is
determined. We backtrack over all the states and moves to update the Q-table,
giving the appropriate reward to each player. Since the learning is episodic,
it can only be done at the end.

The learning rate $\alpha$ is set to $1$ because the game is fully
deterministic.

We use an $\varepsilon$-greedy (exponentially decaying) strategy for
exploration/exploitation.

The Bellman equation is simplified to the bare minimum for the special case of
an episodic, deterministic, two-player game.
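
Concretely, with $\alpha = 1$ and no bootstrapping term, the end-of-episode update in `q_learning.py` amounts to adding the final reward, with alternating sign, to every visited (state, action) pair. A stripped-down sketch of that update and of the $\varepsilon$ decay used for exploration:

```python
import numpy as np

Q = np.zeros((627, 9))  # state space x action space

def update_episode(Q, visited, reward, alpha=1.0):
    """visited: list of (state_index, action) in play order; reward is +1/-1/0 for player 1."""
    r = reward
    for s, a in visited:
        Q[s, a] += alpha * r  # no Q[s', a'] term: purely episodic, deterministic update
        r *= -1               # players alternate, so the sign of the reward alternates

# exponentially decaying epsilon, as a function of the episode number k
min_eps, max_eps, decay = 0.01, 1.0, 0.001
def epsilon(k):
    return min_eps + (max_eps - min_eps) * np.exp(-decay * k)
```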

Maybe some reward shaping could be done to get better results, and we could
also try a more complete version of the Bellman equation by considering
$Q[s+1, a]$, which we do not do right now. This would require handling the
special case of the winning boards, which are not stored in order to reduce
the state space size.

board_hash_list.pkl (new file, binary, not shown)

game.py (new file, 119 lines)

```python
# implement a simple text based tic-tac-toe game
import numpy as np


class TicTacToe():

    def __init__(self, size=(3, 3)):
        self.size = size
        self.board = np.zeros(size, dtype=int)
        self.player = 1  # the player whose turn it is to play
        self.winner = 0  # the winner, 0 for draw
        self.done = 0    # whether the game is over or not
        self.turn = 0    # the current turn, maximum is 9

    def input_is_valid(self, move):
        """
        Check if the move one of the players wants to make is valid in the
        current situation.

        The move is valid if the target cell is empty and the move is inside
        the grid.
        """

        # todo catch error if out of bounds
        unraveled_index = np.unravel_index(move, self.size)

        if self.board[unraveled_index] == 0:
            return True
        else:
            return False

    def update_board(self, move):
        """
        Update the board if the move is valid.
        """

        if self.input_is_valid(move):
            unraveled_index = np.unravel_index(move, self.size)
            self.board[unraveled_index] = self.player

            # change player
            self.player = 3 - self.player

            # update turn counter
            self.turn += 1

    def check_win(self):
        """
        Check if the current board is in a winning position.
        Return 0 if it is not and the player id (1 or 2) if it is.
        """

        # impossible to win in less than 5 moves
        if self.turn < 5:
            return 0

        # check rows
        rows = self.board
        for r in rows:
            if (r == 1).all():
                self.winner = 1
            if (r == 2).all():
                self.winner = 2

        # check columns
        columns = self.board.T
        for c in columns:
            if (c == 1).all():
                self.winner = 1
            if (c == 2).all():
                self.winner = 2

        # check diagonals
        diagonals = [self.board.diagonal(), np.fliplr(self.board).diagonal()]
        for d in diagonals:
            if (d == 1).all():
                self.winner = 1
            if (d == 2).all():
                self.winner = 2

        # handle draw
        if self.turn == 9:
            self.done = 1

        # if someone won
        if self.winner != 0:
            self.done = 1

        # if no winning condition has been found, self.winner is still 0
        return self.winner

    def display_board(self):
        """
        Display a human-readable board.

        Example:
        +---+---+---+
        | X | O |   |
        +---+---+---+
        | O | X |   |
        +---+---+---+
        | X |   | O |
        +---+---+---+
        """

        line_decorator = "+---+---+---+"
        board_decorator = "| {} | {} | {} |"

        convert = {0: " ", 1: "X", 2: "O"}

        for r in self.board:
            print(line_decorator)

            L = []
            for i in r:
                L.append(convert[i])

            print(board_decorator.format(*L))

        print(line_decorator)
```
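
For orientation, a minimal way to drive the class (random legal moves until the game ends); this loop is only an illustration and is not part of game.py:

```python
from random import randint

from game import TicTacToe

ttt = TicTacToe()
while not ttt.done:
    move = randint(1, 9) - 1      # 0-based flat index, as expected by the class
    if ttt.input_is_valid(move):
        ttt.update_board(move)
        ttt.winner = ttt.check_win()

ttt.display_board()
print("winner:", ttt.winner)      # 1, 2, or 0 for a draw
```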

generate_board_hash_list.py (new file, 74 lines)

```python
from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint
from hashlib import sha1
import pickle

board_list = []

for k in tqdm(range(10000)):
    #tqdm.write(f"Game {k}")
    ttt = TicTacToe()
    #tqdm.write(str(len(board_list)))
    while not ttt.done:
        #tqdm.write(f"turn: {ttt.turn}")
        #ttt.display_board()
        # compute all the symmetries
        b1 = ttt.board
        b2 = np.rot90(b1, 1)
        b3 = np.rot90(b1, 2)
        b4 = np.rot90(b1, 3)
        b5 = np.fliplr(b1)
        b6 = np.flipud(b1)
        b7 = b1.T           # mirror diagonally
        b8 = np.fliplr(b2)  # mirror anti-diagonally

        # compute the hashes of all the symmetries; stop as soon as one of
        # them is already known
        list_hd = []
        for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
            #list_hd.append()
            hd = sha1(np.ascontiguousarray(b)).hexdigest()
            if hd in board_list:
                break
        # store a single hash per equivalence class of boards
        if hd not in board_list:
            board_list.append(hd)

        # choose randomly
        flag = False
        while flag is not True:
            move = randint(1, 9) - 1
            flag = ttt.input_is_valid(move)

        # choose from the list of available moves
        #zeros = np.where(ttt.board == 0)
        #a = len(zeros[0])
        #i = randint(0, a-1)
        #move = np.ravel_multi_index((zeros[0][i], zeros[1][i]), (3,3))
        # not faster than the random method above
        # the method above is easier to understand

        #tqdm.write(str(move))
        ttt.update_board(move)
        ttt.winner = ttt.check_win()

    # premature ending as soon as the 627 possibilities have been found
    if len(board_list) == 627:
        tqdm.write(f"breaking at {k}")
        break

# number of non-ending boards (sorted by number of turns)
# {0: 1, 1: 3, 2: 12, 3: 38, 4: 108, 5: 153, 6: 183, 7: 95, 8: 34, 9: 0}

# number of all boards (sorted by number of turns)
# {0: 1, 1: 3, 2: 12, 3: 38, 4: 108, 5: 174, 6: 204, 7: 153, 8: 57, 9: 15}

# (it should always be 627)
print(f"Number of different (non-ending) boards: {len(board_list)}")

# dump the pickle object
fd = open("board_hash_list.pkl", "wb")
pickle.dump(board_list, fd)
fd.close()
```
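
A quick sanity check of the generated file (just an illustration, not part of the repository):

```python
import pickle

with open("board_hash_list.pkl", "rb") as fd:
    board_list = pickle.load(fd)

print(len(board_list))  # expected: 627 non-ending boards
print(board_list[0])    # a SHA-1 hex digest, one entry per equivalence class
```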

q_learning.py (new file, 207 lines)

```python
from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint, uniform
from hashlib import sha1
import pickle
import matplotlib.pyplot as plt

# Retrieve the board hash list from the pickled data
# (the file produced by generate_board_hash_list.py)
fd = open("board_hash_list.pkl", "rb")
board_set = pickle.load(fd)
fd.close()

# Make it a list
board_list = list(board_set)


def state_index_from_board(board):
    """
    Return the Q-table index of a given board.

    For any board represented by a 3x3 ndarray, return the index in the Q-table
    of the corresponding state. Takes into account the 8 symmetries.

    There are 627 different non-ending states in the Q-table.
    """

    # Compute all the symmetries
    b1 = board
    b2 = np.rot90(b1, 1)
    b3 = np.rot90(b1, 2)
    b4 = np.rot90(b1, 3)
    b5 = np.fliplr(b1)
    b6 = np.flipud(b1)
    b7 = b1.T           # mirror diagonally
    b8 = np.fliplr(b2)  # mirror anti-diagonally

    # Compute the hashes of the symmetries and find the index
    for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
        hd = sha1(np.ascontiguousarray(b)).hexdigest()
        if hd in board_list:
            state_index = board_list.index(hd)
            ttt.board = b  # rotate the (global) board to its canonical orientation

    return state_index


def exploitation(ttt, Q):
    """
    Exploit the Q-table.

    Fully exploit the knowledge from the Q-table to retrieve the next action to
    take in the given state.

    Parameters
    ==========
    ttt: a game, i.e. an instance of the TicTacToe class
    Q: the Q-table
    """

    state_index = state_index_from_board(ttt.board)

    if ttt.player == 1:
        a = np.argmax(Q[state_index, :])
    else:
        a = np.argmin(Q[state_index, :])

    return a


def random_agent(ttt):
    flag = False
    while flag is not True:
        move = randint(1, 9) - 1
        flag = ttt.input_is_valid(move)
    return move


# initialisation of the Q-table
state_size = 627   # hardcoded, calculated from experiment
action_size = 9    # the 9 possible moves at each turn
Q = np.zeros((state_size, action_size))

# hyper parameters
num_episodes = 10000
max_test = 9  # max number of steps per episode
alpha = 1     # learning rate (problem is deterministic)

# exploration / exploitation parameters
epsilon = 1
max_epsilon = 1
min_epsilon = 0.01
decay_rate = 0.001

winner_history = []
reward_history = []
exploit_history = []
epsilon_history = []

for k in tqdm(range(num_episodes)):

    list_state_move = []

    # reset the board
    ttt = TicTacToe()

    while not ttt.done:
        #ttt.display_board()
        # get the state and rotate the board to its canonical orientation
        state_index = state_index_from_board(ttt.board)
        #print(f"state_index = {state_index}")
        #print(Q[state_index])

        # explore or exploit?
        tradeoff = uniform(0, 1)

        do_exploit = tradeoff > epsilon
        # debug: never exploit
        do_exploit = False

        if do_exploit:
            # exploit
            move = exploitation(ttt, Q)
            if not ttt.input_is_valid(move):
                move = random_agent(ttt)

        else:
            # Random Agent (exploration)
            move = random_agent(ttt)

        # remember the couples (s, a) to give the reward at the end of the episode
        list_state_move.append((state_index, move))

        ttt.update_board(move)
        ttt.winner = ttt.check_win()

    if k % 1000 == 0:
        tqdm.write(f"episode: {k}, epsilon: {epsilon:0.3f}, winner: {ttt.winner}, \
exploited: {tradeoff > epsilon}")

    # reward shaping
    if ttt.winner == 1:
        r = 1
    elif ttt.winner == 2:
        r = -1
    else:
        r = 0  # draw

    #print(r)
    reward_history.append(r)

    # Update the Q-table (not yet using the Bellman equation, the case is too simple)
    for s, a in list_state_move:
        Q[s, a] += alpha * r
        r *= -1  # flip the reward for the next move since players alternate

    # Update the epsilon-greedy (decreasing) strategy
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * k)

    # remember stats for history
    winner_history.append(ttt.winner)
    epsilon_history.append(epsilon)
    exploit_history.append(tradeoff > epsilon)
    # input()
    #import code
    #code.interact(local=locals())


# Helper functions

def player_move(ttt):
    ttt.display_board()
    state_index_from_board(ttt.board)
    ttt.display_board()
    print("Human's turn")

    flag = False
    while flag is not True:
        move = int(input("What is your move? [1-9] "))
        move -= 1  # range is 0-8 in the array
        flag = ttt.input_is_valid(move)

    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


def ai_move(ttt, Q):
    ttt.display_board()
    print("AI's turn")

    move = exploitation(ttt, Q)

    ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


# plot graphs
cumulative_win = np.cumsum(reward_history)
plt.plot(cumulative_win)
plt.title("Cumulative reward")
plt.xlabel("epochs")
plt.ylabel("cumsum of rewards")
plt.show()

plt.plot(epsilon_history)
plt.title("Epsilon history")
plt.show()
```
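
The `player_move` and `ai_move` helpers are defined but never called in the script. A possible interactive loop using them after training (human as player 1, AI as player 2), written as a sketch of lines one might append at the end of q_learning.py:

```python
ttt = TicTacToe()
while not ttt.done:
    if ttt.player == 1:
        player_move(ttt)   # human plays X
    else:
        ai_move(ttt, Q)    # greedy move from the learned Q-table
print("winner:", ttt.winner)
```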