🎉 first commit
commit 9823a78177
5 changed files with 476 additions and 0 deletions
README.md (new file, 76 lines)
@@ -0,0 +1,76 @@
# TicTacToe

This repository is a simple implementation of the game of TicTacToe and some
experiments around Reinforcement Learning with it.

## Structure

* `game.py` contains the implementation of the game itself.
* `generate_board_hash_list.py` creates a pickle object containing a list of hashes,
  one for every possible unique non-ending configuration of the board. It is needed
  to create the Q-table later and must be precomputed.
* `q_learning.py` contains some experimentation with Q-learning, using the
  TicTacToe game as an example.
## Implementation details

The TicTacToe game is a Python class. The board is a 3x3 ndarray (NumPy) of
dtype `int`. Input is taken from 1 to 9 following this scheme:

```
+---+---+---+
| 1 | 2 | 3 |
+---+---+---+
| 4 | 5 | 6 |
+---+---+---+
| 7 | 8 | 9 |
+---+---+---+
```

It is automatically raveled/unraveled when necessary.
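For illustration, a minimal sketch of this mapping (it mirrors the `np.unravel_index` call used in `game.py`):

```python
import numpy as np

move = 5                                        # the player types 5 (scheme above)
row, col = np.unravel_index(move - 1, (3, 3))   # 0-based flat index -> (row, col)
print(row, col)                                 # 1 1, i.e. the centre cell
```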
We only need to check for a win from the 5th move onwards, because it is
impossible to have a winner before that. At 9 moves the board is full and the
game is considered a draw if no one has won.
## Combinatorics

Without taking anything into account, we can estimate an upper bound on the
number of possible boards: there are $3^9 = 19683$ possibilities.

There are 8 possible symmetries (the dihedral group of order 8, a.k.a. the
symmetry group of the square). This drastically reduces the number of possible
boards.

Taking into account the symmetries and the impossible boards (more O than X,
for example), we get $765$ boards.

Since we do not need to store the ending boards in the DAG of game states,
this number drops to $627$ non-ending boards.

This makes our state space size $627$ and our action space size $9$.
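For illustration, a minimal sketch of how symmetric duplicates are collapsed in practice; it mirrors the hashing loop of `generate_board_hash_list.py` below, which is how the $627$ figure is reproduced:

```python
import numpy as np
from hashlib import sha1

def symmetry_hashes(board):
    """SHA-1 digests of the 8 dihedral symmetries of a 3x3 board."""
    b2 = np.rot90(board, 1)
    variants = [board, b2, np.rot90(board, 2), np.rot90(board, 3),
                np.fliplr(board), np.flipud(board), board.T, np.fliplr(b2)]
    return {sha1(np.ascontiguousarray(b)).hexdigest() for b in variants}

# two boards that are rotations of each other share at least one hash,
# so they count as a single state
a = np.zeros((3, 3), dtype=int)
a[0, 0] = 1
b = np.rot90(a)
print(bool(symmetry_hashes(a) & symmetry_hashes(b)))  # True
```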
## Reward

* `+1` for the winning side
* `-1` for the losing side
* `±0` in case of a draw

The rewards are given only at the end of an episode, when the winner is
determined. We backtrack over all the states and moves to update the Q-table,
giving the appropriate reward to each player.
Since the learning is episodic, it can only be done at the end.

The learning rate $\alpha$ is set to $1$ because the game is fully
deterministic.

We use an $\varepsilon$-greedy (exponentially decreasing) strategy for
exploration/exploitation.

The Bellman equation is simplified to the bare minimum for the special case of
an episodic, deterministic, 2-player game.
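For reference, the standard Q-learning update is

$$ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right], $$

but since the reward is only given at the end of the episode and the successor term $\gamma \max_{a'} Q(s', a')$ is not used (see the note below), the update actually applied to every $(s, a)$ pair of the episode is simply

$$ Q(s, a) \leftarrow Q(s, a) + \alpha \, r, $$

with the sign of $r$ flipped between consecutive moves so that each player is rewarded from its own point of view (this is the `Q[s,a] += alpha * r` loop in `q_learning.py`).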
Some reward shaping could be done to get better results, and we could also try
a more complete version of the Bellman equation by considering Q[s+1, a], which
we do not do right now. This would require handling the special case of the
winning boards, which are not stored, in order to reduce the state space size.
board_hash_list.pkl (new file, binary)
Binary file not shown.
game.py (new file, 119 lines)
@@ -0,0 +1,119 @@
# implement a simple text based tic-tac-toe game
import numpy as np


class TicTacToe():

    def __init__(self, size=(3, 3)):
        self.size = size
        self.board = np.zeros(size, dtype=int)
        self.player = 1  # the player whose turn it is to play
        self.winner = 0  # the winner, 0 for a draw
        self.done = 0    # whether the game is over (1) or not (0)
        self.turn = 0    # the current turn, maximum is 9

    def input_is_valid(self, move):
        """
        Check if the move one of the players wants to make is valid in the
        current situation.

        The move is valid if it is inside the grid and there is nothing where
        the player wants to play.
        """
        # reject out-of-bounds moves instead of letting unravel_index raise
        if not 0 <= move < self.board.size:
            return False

        unraveled_index = np.unravel_index(move, self.size)

        if self.board[unraveled_index] == 0:
            return True
        else:
            return False

    def update_board(self, move):
        """
        Update the board if the move is valid.
        """
        if self.input_is_valid(move):
            unraveled_index = np.unravel_index(move, self.size)
            self.board[unraveled_index] = self.player

            # change player (1 -> 2, 2 -> 1)
            self.player = 3 - self.player

            # update turn counter
            self.turn += 1

    def check_win(self):
        """
        Check if the current board is in a winning position.
        Return 0 if it is not and the player id (1 or 2) if it is.
        """
        # impossible to win in less than 5 moves
        if self.turn < 5:
            return 0

        # check rows
        rows = self.board
        for r in rows:
            if (r == 1).all():
                self.winner = 1
            if (r == 2).all():
                self.winner = 2

        # check columns
        columns = self.board.T
        for c in columns:
            if (c == 1).all():
                self.winner = 1
            if (c == 2).all():
                self.winner = 2

        # check diagonals
        diagonals = [self.board.diagonal(), np.fliplr(self.board).diagonal()]
        for d in diagonals:
            if (d == 1).all():
                self.winner = 1
            if (d == 2).all():
                self.winner = 2

        # handle draw
        if self.turn == 9:
            self.done = 1

        # if someone won
        if self.winner != 0:
            self.done = 1

        # if no winning condition has been found, self.winner is still 0
        return self.winner

    def display_board(self):
        """
        Display a human-readable board.

        Example:
        +---+---+---+
        | X | O |   |
        +---+---+---+
        | O | X |   |
        +---+---+---+
        | X |   | O |
        +---+---+---+
        """
        line_decorator = "+---+---+---+"
        board_decorator = "| {} | {} | {} |"

        convert = {0: " ", 1: "X", 2: "O"}

        for r in self.board:
            print(line_decorator)

            L = []
            for i in r:
                L.append(convert[i])

            print(board_decorator.format(*L))

        print(line_decorator)
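For a quick feel of the API, a minimal self-play sketch using only the methods defined above (random moves, as in the scripts below):

```python
from random import randint

from game import TicTacToe

ttt = TicTacToe()
while not ttt.done:
    move = randint(1, 9) - 1        # random 0-based move
    if ttt.input_is_valid(move):
        ttt.update_board(move)
        ttt.winner = ttt.check_win()

ttt.display_board()
print("winner:", ttt.winner)        # 1, 2, or 0 for a draw
```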
generate_board_hash_list.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint
from hashlib import sha1
import pickle

board_list = []

for k in tqdm(range(10000)):
    #tqdm.write(f"Game {k}")
    ttt = TicTacToe()
    #tqdm.write(str(len(board_list)))
    while not ttt.done:
        #tqdm.write(f"turn: {ttt.turn}")
        #ttt.display_board()

        # compute all the symmetries
        b1 = ttt.board
        b2 = np.rot90(b1, 1)
        b3 = np.rot90(b1, 2)
        b4 = np.rot90(b1, 3)
        b5 = np.fliplr(b1)
        b6 = np.flipud(b1)
        b7 = b1.T           # mirror diagonally
        b8 = np.fliplr(b2)  # mirror anti-diagonally

        # compute the hash of every symmetry: stop as soon as one of them is
        # already known, otherwise register this board with the last hash
        for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
            hd = sha1(np.ascontiguousarray(b)).hexdigest()
            if hd in board_list:
                break
        if hd not in board_list:
            board_list.append(hd)

        # choose a move randomly
        flag = False
        while flag is not True:
            move = randint(1, 9) - 1
            flag = ttt.input_is_valid(move)

        # choosing from the list of available moves:
        #zeros = np.where(ttt.board == 0)
        #a = len(zeros[0])
        #i = randint(0, a-1)
        #move = np.ravel_multi_index((zeros[0][i], zeros[1][i]), (3,3))
        # not faster than the random method above,
        # and the method above is easier to understand

        #tqdm.write(str(move))
        ttt.update_board(move)
        ttt.winner = ttt.check_win()

    # premature ending as soon as the 627 possibilities have been found
    if len(board_list) == 627:
        tqdm.write(f"breaking at {k}")
        break

# number of non-ending boards (sorted by number of turns)
# {0: 1, 1: 3, 2: 12, 3: 38, 4: 108, 5: 153, 6: 183, 7: 95, 8: 34, 9: 0}

# number of all boards (sorted by number of turns)
# {0: 1, 1: 3, 2: 12, 3: 38, 4: 108, 5: 174, 6: 204, 7: 153, 8: 57, 9: 15}

# (it should always be 627)
print(f"Number of different (non-ending) boards: {len(board_list)}")

# dump the pickle object
fd = open("board_hash_list.pkl", "wb")
pickle.dump(board_list, fd)
fd.close()
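A minimal sketch to sanity-check the generated artifact (assuming the committed `board_hash_list.pkl` sits in the working directory):

```python
import pickle

with open("board_hash_list.pkl", "rb") as fd:
    board_list = pickle.load(fd)

print(len(board_list))  # expected: 627 unique non-ending board hashes
```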
q_learning.py (new file, 207 lines)
@@ -0,0 +1,207 @@
from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint, uniform
from hashlib import sha1
import pickle
import matplotlib.pyplot as plt

# Retrieve the board hash list from the pickled data
# (generated by generate_board_hash_list.py)
fd = open("board_hash_list.pkl", "rb")
board_set = pickle.load(fd)
fd.close()

# Make it a list
board_list = list(board_set)


def state_index_from_board(board):
    """
    Return the Q-table index of a given board.

    For any board represented by a 3x3 ndarray, return the index in the Q-table
    of the corresponding state. Takes into account the 8 symmetries.

    There are 627 different non-ending states in the Q-table.
    """
    # compute all the symmetries
    b1 = board
    b2 = np.rot90(b1, 1)
    b3 = np.rot90(b1, 2)
    b4 = np.rot90(b1, 3)
    b5 = np.fliplr(b1)
    b6 = np.flipud(b1)
    b7 = b1.T           # mirror diagonally
    b8 = np.fliplr(b2)  # mirror anti-diagonally

    # compute the hash of the symmetries and find the index
    for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
        hd = sha1(np.ascontiguousarray(b)).hexdigest()
        if hd in board_list:
            state_index = board_list.index(hd)
            ttt.board = b  # rotate the (global) board to the canonical value

    return state_index


def exploitation(ttt, Q):
    """
    Exploit the Q-table.

    Fully exploit the knowledge from the Q-table to retrieve the next action to
    take in the given state.

    Parameters
    ==========
    ttt: a game, i.e. an instance of the TicTacToe class
    Q: the Q-table
    """
    state_index = state_index_from_board(ttt.board)

    # player 1 maximises the value, player 2 minimises it
    if ttt.player == 1:
        a = np.argmax(Q[state_index, :])
    else:
        a = np.argmin(Q[state_index, :])

    return a


def random_agent(ttt):
    """Return a random valid move for the current board."""
    flag = False
    while flag is not True:
        move = randint(1, 9) - 1
        flag = ttt.input_is_valid(move)
    return move


# initialisation of the Q-table
state_size = 627  # hardcoded, computed by generate_board_hash_list.py
action_size = 9   # the 9 possible moves at each turn
Q = np.zeros((state_size, action_size))

# hyperparameters
num_episodes = 10000
max_test = 9  # max number of steps per episode
alpha = 1     # learning rate (the problem is deterministic)

# exploration / exploitation parameters
epsilon = 1
max_epsilon = 1
min_epsilon = 0.01
decay_rate = 0.001

winner_history = []
reward_history = []
exploit_history = []
epsilon_history = []

for k in tqdm(range(num_episodes)):

    list_state_move = []

    # reset the board
    ttt = TicTacToe()

    while not ttt.done:
        #ttt.display_board()
        # get the state index and rotate the board to its canonical orientation
        state_index = state_index_from_board(ttt.board)
        #print(f"state_index = {state_index}")
        #print(Q[state_index])

        # explore or exploit?
        tradeoff = uniform(0, 1)

        do_exploit = tradeoff > epsilon
        # debug: never exploit
        do_exploit = False

        if do_exploit:
            # exploit
            move = exploitation(ttt, Q)
            if not ttt.input_is_valid(move):
                move = random_agent(ttt)
        else:
            # random agent (exploration)
            move = random_agent(ttt)

        # remember the couples (s, a) to give the reward at the end of the episode
        list_state_move.append((state_index, move))

        ttt.update_board(move)
        ttt.winner = ttt.check_win()

    if k % 1000 == 0:
        tqdm.write(f"episode: {k}, epsilon: {epsilon:0.3f}, "
                   f"winner: {ttt.winner}, exploited: {tradeoff > epsilon}")

    # terminal reward (no reward shaping yet, see README)
    if ttt.winner == 1:
        r = 1
    elif ttt.winner == 2:
        r = -1
    else:
        r = 0  # draw

    #print(r)
    reward_history.append(r)

    # update the Q-table (not yet using the Bellman equation, the case is too simple)
    for s, a in list_state_move:
        Q[s, a] += alpha * r
        r *= -1  # invert the reward for the next move due to player inversion

    # update the epsilon-greedy (decreasing) strategy
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * k)

    # remember stats for history
    winner_history.append(ttt.winner)
    epsilon_history.append(epsilon)
    exploit_history.append(tradeoff > epsilon)

    # input()
    #import code
    #code.interact(local=locals())


# Helper functions

def player_move(ttt):
    ttt.display_board()
    # rotate the board to its canonical orientation
    state_index_from_board(ttt.board)
    ttt.display_board()
    print("Human's turn")

    flag = False
    while flag is not True:
        move = int(input("What is your move? [1-9] "))
        move -= 1  # range is 0-8 in the array
        flag = ttt.input_is_valid(move)

    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


def ai_move(ttt, Q):
    ttt.display_board()
    print("AI's turn")

    move = exploitation(ttt, Q)

    # note: the result is not checked; an invalid greedy move leaves the board unchanged
    ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


# plot graphs
cumulative_win = np.cumsum(reward_history)
plt.plot(cumulative_win)
plt.title("Cumulative reward")
plt.xlabel("episodes")
plt.ylabel("cumsum of rewards")
plt.show()

plt.plot(epsilon_history)
plt.title("Epsilon history")
plt.show()
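The helpers `player_move` and `ai_move` above are defined but not called anywhere in the script; as a minimal post-training sketch, the learned table can play a full game greedily, reusing the same random fallback as the training loop:

```python
# play one greedy game (Q-table vs. Q-table) after training
ttt = TicTacToe()
while not ttt.done:
    move = exploitation(ttt, Q)       # greedy action from the Q-table
    if not ttt.input_is_valid(move):  # fall back to a random valid move
        move = random_agent(ttt)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()

ttt.display_board()
print("winner:", ttt.winner)
```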