commit 9823a781777022bcded709e014edaabb98bec198
Author: Otthorn
Date:   Mon May 17 00:50:03 2021 +0200

    :tada: first commit

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..227aa3a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,76 @@
+# TicTacToe
+
+This repository is a simple implementation of the game of TicTacToe and some
+experiments around Reinforcement Learning with it.
+
+## Structure
+
+* `game.py` contains the implementation of the game itself.
+* `generate_board_hash_list.py` creates a pickle object containing a list of
+hashes for every possible unique non-ending board position. It is needed to
+build the Q-table later and must be precomputed.
+* `q_learning.py` contains some experimentation with Q-learning using the
+  TicTacToe game as an example.
+
+## Implementation details
+
+The TicTacToe game is a Python class. The board is a 3x3 ndarray (numpy) of
+dtype `int`. Input is taken from 1 to 9 following this scheme:
+
+```
++---+---+---+
+| 1 | 2 | 3 |
++---+---+---+
+| 4 | 5 | 6 |
++---+---+---+
+| 7 | 8 | 9 |
++---+---+---+
+```
+It is automatically raveled/unraveled when necessary.
+
+We only need to check for a win from move 5 onwards, because it is impossible
+to have a winner before that. At 9 moves the board is full and the game is
+considered a draw if no one has won.
+
+## Combinatorics
+
+Without taking anything into account, we can estimate an upper bound on the
+number of possible boards: there are $3^9 = 19683$ possibilities.
+
+There are 8 different symmetries (the dihedral group of order 8, a.k.a. the
+symmetry group of the square). This drastically reduces the number of possible
+boards.
+
+Taking into account the symmetries and removing the impossible boards (more O
+than X, for example), we get $765$ boards.
+
+Since we do not need to store the last board in the DAG, this number drops to
+$627$ non-ending boards.
+
+This makes our state space size $627$ and our action space size $9$.
+
+## Reward
+
+* `+1` for the winning side
+* `-1` for the losing side
+* `±0` in case of a draw
+
+The rewards are given only at the end of an episode, when the winner is
+determined. We backtrack over all the states and moves of the episode to
+update the Q-table, giving the appropriate reward to each player.
+Since the learning is episodic, it can only be done at the end.
+
+The learning rate $\alpha$ is set to $1$ because the game is fully
+deterministic.
+
+We use an $\varepsilon$-greedy (exponentially decreasing) strategy for
+exploration/exploitation.
+
+The Bellman equation is simplified to the bare minimum for the special case of
+an episodic, deterministic, 2-player game.
+
+Some reward shaping could be done to get better results, and we could also try
+a more complete version of the Bellman equation by considering Q[s+1,a], which
+we do not use right now. This would require handling the special case of the
+winning boards, which are not stored in order to reduce the state space size.
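The "Implementation details" and "Combinatorics" sections above describe the 1-9 move indexing and the 8-fold symmetry reduction only in prose. Below is a minimal, self-contained sketch of both ideas; the helper names `move_to_cell`, `symmetries` and `board_hashes` are illustrative and do not appear in the repository, which inlines this logic in `generate_board_hash_list.py` and `q_learning.py`.

```python
import numpy as np
from hashlib import sha1

def move_to_cell(move):
    # The user types 1-9; internally the move is 0-8 and unraveled to (row, col)
    return np.unravel_index(move - 1, (3, 3))

def symmetries(board):
    # The 8 elements of the dihedral group of the square:
    # 4 rotations, the horizontal/vertical mirrors and the two diagonal mirrors
    return [board,
            np.rot90(board, 1), np.rot90(board, 2), np.rot90(board, 3),
            np.fliplr(board), np.flipud(board),
            board.T, np.fliplr(np.rot90(board, 1))]

def board_hashes(board):
    # SHA-1 digest of every symmetric variant; a position counts as "new" only
    # if none of these digests has been seen before, which is how the 19683
    # raw boards collapse to 765 distinct ones (627 when ending boards are dropped)
    return [sha1(np.ascontiguousarray(b)).hexdigest() for b in symmetries(board)]

board = np.zeros((3, 3), dtype=int)
board[move_to_cell(5)] = 1            # player 1 takes the centre (input "5")
row, col = move_to_cell(5)
print(row, col)                       # 1 1 -> the centre cell
print(len(set(board_hashes(board))))  # 1: all 8 symmetries of this board coincide
```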
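The "Reward" section describes the episodic Q-table update and mentions a more complete Bellman backup as possible future work. The sketch below shows what both could look like: `episodic_update` mirrors the update loop in `q_learning.py`, while `one_step_update`, the `gamma` value and the example state indices are illustrative assumptions rather than code from the repository.

```python
import numpy as np

alpha = 1.0   # learning rate; set to 1 because the game is fully deterministic
gamma = 0.9   # discount factor, only needed for the fuller backup (illustrative value)

def episodic_update(Q, trajectory, final_reward):
    """Walk over the (state, action) pairs of one finished episode in move
    order, flipping the sign of the reward at every step because the two
    players alternate. This mirrors the update loop in q_learning.py."""
    r = final_reward
    for s, a in trajectory:
        Q[s, a] += alpha * r
        r *= -1

def one_step_update(Q, s, a, r, s_next):
    """One possible form of the fuller backup mentioned in the README:
    Q[s,a] += alpha * (r + gamma * max_a' Q[s_next,a'] - Q[s,a]).
    As noted above, this would require storing the terminal (winning) boards
    as well, and a two-player version would also have to account for the
    opponent's reply (e.g. a minimax-style target)."""
    target = r + gamma * np.max(Q[s_next, :])
    Q[s, a] += alpha * (target - Q[s, a])

# Example: player 1 wins a 5-move game; the state indices here are made up
Q = np.zeros((627, 9))
episodic_update(Q, [(0, 4), (3, 0), (17, 8), (42, 1), (101, 2)], final_reward=1)
```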
diff --git a/board_hash_list.pkl b/board_hash_list.pkl
new file mode 100644
index 0000000..1c577a3
Binary files /dev/null and b/board_hash_list.pkl differ
diff --git a/game.py b/game.py
new file mode 100644
index 0000000..898b365
--- /dev/null
+++ b/game.py
@@ -0,0 +1,119 @@
+# implement a simple text-based tic-tac-toe game
+import numpy as np
+
+class TicTacToe():
+
+    def __init__(self, size=(3,3)):
+        self.size = size
+        self.board = np.zeros(size, dtype=int)
+        self.player = 1  # the player whose turn it is to play
+        self.winner = 0  # the winner, 0 for a draw
+        self.done = 0    # whether the game is over or not
+        self.turn = 0    # the current turn, maximum is 9
+
+    def input_is_valid(self, move):
+        """
+        Check if the move one of the players wants to take is valid in the
+        current situation.
+
+        The move is valid if the target cell is empty and the move is inside
+        the grid.
+        """
+
+        # TODO: catch the error if the move is out of bounds
+        unraveled_index = np.unravel_index(move, self.size)
+
+        if self.board[unraveled_index] == 0:
+            return True
+        else:
+            return False
+
+    def update_board(self, move):
+        """
+        Update the board if the move is valid.
+        """
+
+        if self.input_is_valid(move):
+            unraveled_index = np.unravel_index(move, self.size)
+            self.board[unraveled_index] = self.player
+
+            # change player
+            self.player = 3 - self.player
+
+            # update the turn counter
+            self.turn += 1
+
+    def check_win(self):
+        """
+        Check if the current board is in a winning position.
+        Return 0 if it is not and the player id (1 or 2) if it is.
+        """
+
+        # impossible to win in less than 5 moves
+        if self.turn < 5:
+            return 0
+
+        # check rows
+        rows = self.board
+        for r in rows:
+            if (r == 1).all():
+                self.winner = 1
+            if (r == 2).all():
+                self.winner = 2
+
+        # check columns
+        columns = self.board.T
+        for c in columns:
+            if (c == 1).all():
+                self.winner = 1
+            if (c == 2).all():
+                self.winner = 2
+
+        # check diagonals
+        diagonals = [self.board.diagonal(), np.fliplr(self.board).diagonal()]
+        for d in diagonals:
+            if (d == 1).all():
+                self.winner = 1
+            if (d == 2).all():
+                self.winner = 2
+
+        # handle a draw: at turn 9 the board is full
+        if self.turn == 9:
+            self.done = 1
+
+        # if someone won, the game is over
+        if self.winner != 0:
+            self.done = 1
+
+        # still 0 if no winning condition has been found
+        return self.winner
+
+    def display_board(self):
+        """
+        Display a human-readable board.
+
+        Example:
+        +---+---+---+
+        | X | O |   |
+        +---+---+---+
+        | O | X |   |
+        +---+---+---+
+        | X |   | O |
+        +---+---+---+
+        """
+
+        line_decorator = "+---+---+---+"
+        board_decorator = "| {} | {} | {} |"
+
+        convert = {0: " ", 1: "X", 2: "O"}
+
+        for r in self.board:
+            print(line_decorator)
+
+            L = []
+            for i in r:
+                L.append(convert[i])
+
+            print(board_decorator.format(*L))
+
+        print(line_decorator)
diff --git a/generate_board_hash_list.py b/generate_board_hash_list.py
new file mode 100644
index 0000000..49f57cd
--- /dev/null
+++ b/generate_board_hash_list.py
@@ -0,0 +1,74 @@
+from game import TicTacToe
+from tqdm import tqdm
+import numpy as np
+from random import randint
+from hashlib import sha1
+import pickle
+
+board_list = []
+
+for k in tqdm(range(10000)):
+    #tqdm.write(f"Game {k}")
+    ttt = TicTacToe()
+    #tqdm.write(str(len(board_list)))
+    while not ttt.done:
+        #tqdm.write(f"turn: {ttt.turn}")
+        #ttt.display_board()
+        # compute all the symmetries
+        b1 = ttt.board
+        b2 = np.rot90(b1, 1)
+        b3 = np.rot90(b1, 2)
+        b4 = np.rot90(b1, 3)
+        b5 = np.fliplr(b1)
+        b6 = np.flipud(b1)
+        b7 = b1.T           # mirror diagonally
+        b8 = np.fliplr(b2)  # mirror anti-diagonally
+
+        # compute the hash of each symmetry; keep the board only if none of
+        # its symmetric variants has been seen before
+        for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
+            hd = sha1(np.ascontiguousarray(b)).hexdigest()
+            if hd in board_list:
+                break
+        if hd not in board_list:
+            board_list.append(hd)
+
+        # choose a random valid move
+        flag = False
+        while flag is not True:
+            move = randint(1, 9) - 1
+            flag = ttt.input_is_valid(move)
+
+        # alternative: choose from the list of available moves
+        #zeros = np.where(ttt.board == 0)
+        #a = len(zeros[0])
+        #i = randint(0, a-1)
+        #move = np.ravel_multi_index((zeros[0][i], zeros[1][i]), (3,3))
+        # not faster than the random method above,
+        # and the method above is easier to understand
+
+        #tqdm.write(str(move))
+        ttt.update_board(move)
+        ttt.winner = ttt.check_win()
+
+    # stop early as soon as the 627 possibilities have been found
+    if len(board_list) == 627:
+        tqdm.write(f"breaking at {k}")
+        break
+
+# number of non-ending boards (sorted by number of turns)
+# {0: 1, 1: 3, 2: 12, 3: 38, 4: 108, 5: 153, 6: 183, 7: 95, 8: 34, 9: 0}
+
+# number of all boards (sorted by number of turns)
+# {0: 1, 1: 3, 2: 12, 3: 38, 4: 108, 5: 174, 6: 204, 7: 153, 8: 57, 9: 15}
+
+# (it should always be 627)
+print(f"Number of different (non-ending) boards: {len(board_list)}")
+
+# Dump the pickle object
+fd = open("board_hash_list.pkl", "wb")
+pickle.dump(board_list, fd)
+fd.close()
diff --git a/q_learning.py b/q_learning.py
new file mode 100644
index 0000000..4240d47
--- /dev/null
+++ b/q_learning.py
@@ -0,0 +1,207 @@
+from game import TicTacToe
+from tqdm import tqdm
+import numpy as np
+from random import randint, uniform
+from hashlib import sha1
+import pickle
+import matplotlib.pyplot as plt
+
+# Retrieve the list of board hashes produced by generate_board_hash_list.py
+fd = open("board_hash_list.pkl", "rb")
+board_list = pickle.load(fd)
+fd.close()
+
+def state_index_from_board(board):
+    """
+    Return the Q-table index of a given board.
+
+    For any board represented by a 3x3 ndarray, return the index in the
+    Q-table of the corresponding state. Takes into account the 8 symmetries.
+
+    There are 627 different non-ending states in the Q-table.
+    """
+
+    # Compute all the symmetries
+    b1 = board
+    b2 = np.rot90(b1, 1)
+    b3 = np.rot90(b1, 2)
+    b4 = np.rot90(b1, 3)
+    b5 = np.fliplr(b1)
+    b6 = np.flipud(b1)
+    b7 = b1.T           # mirror diagonally
+    b8 = np.fliplr(b2)  # mirror anti-diagonally
+
+    # Compute the hash of each symmetry and find the index
+    for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
+        hd = sha1(np.ascontiguousarray(b)).hexdigest()
+        if hd in board_list:
+            state_index = board_list.index(hd)
+            ttt.board = b  # rotate the (global) board to the canonical orientation
+
+    return state_index
+
+def exploitation(ttt, Q):
+    """
+    Exploit the Q-table.
+
+    Fully exploit the knowledge from the Q-table to retrieve the next action
+    to take in the given state.
+
+    Parameters
+    ==========
+    ttt: a game, i.e. an instance of the TicTacToe class
+    Q: the Q-table
+    """
+
+    state_index = state_index_from_board(ttt.board)
+
+    if ttt.player == 1:
+        a = np.argmax(Q[state_index, :])
+    else:
+        a = np.argmin(Q[state_index, :])
+
+    return a
+
+
+def random_agent(ttt):
+    flag = False
+    while flag is not True:
+        move = randint(1, 9) - 1
+        flag = ttt.input_is_valid(move)
+    return move
+
+# initialisation of the Q-table
+state_size = 627   # hardcoded, computed by generate_board_hash_list.py
+action_size = 9    # the 9 possible moves at each turn
+Q = np.zeros((state_size, action_size))
+
+# hyperparameters
+num_episodes = 10000
+max_test = 9  # max number of steps per episode
+alpha = 1     # learning rate (the problem is deterministic)
+
+# exploration / exploitation parameters
+epsilon = 1
+max_epsilon = 1
+min_epsilon = 0.01
+decay_rate = 0.001
+
+winner_history = []
+reward_history = []
+exploit_history = []
+epsilon_history = []
+
+for k in tqdm(range(num_episodes)):
+
+    list_state_move = []
+
+    # reset the board
+    ttt = TicTacToe()
+
+    while not ttt.done:
+        #ttt.display_board()
+        # get the state index and rotate the board to its canonical form
+        state_index = state_index_from_board(ttt.board)
+        #print(f"state_index = {state_index}")
+        #print(Q[state_index])
+
+        # explore or exploit?
+        tradeoff = uniform(0, 1)
+
+        do_exploit = tradeoff > epsilon
+        # debug: never exploit
+        do_exploit = False
+
+        if do_exploit:
+            # exploit
+            move = exploitation(ttt, Q)
+            if not ttt.input_is_valid(move):
+                move = random_agent(ttt)
+
+        else:
+            # random agent (exploration)
+            move = random_agent(ttt)
+
+        # remember the couples (s, a) to give the reward at the end of the episode
+        list_state_move.append((state_index, move))
+
+        ttt.update_board(move)
+        ttt.winner = ttt.check_win()
+
+    if k % 1000 == 0:
+        tqdm.write(f"episode: {k}, epsilon: {epsilon:0.3f}, winner: {ttt.winner}, \
+exploited: {tradeoff > epsilon}")
+
+    # reward for the episode
+    if ttt.winner == 1:
+        r = 1
+    elif ttt.winner == 2:
+        r = -1
+    else:
+        r = 0  # draw
+
+    #print(r)
+    reward_history.append(r)
+
+    # Update the Q-table (not yet using the Bellman equation, the case is too simple)
+    for s, a in list_state_move:
+        Q[s, a] += alpha * r
+        r *= -1  # invert the reward for the next move because the players alternate
+
+    # Update epsilon for the (exponentially decreasing) epsilon-greedy strategy
+    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * k)
+
+    # remember stats for history
+    winner_history.append(ttt.winner)
+    epsilon_history.append(epsilon)
+    exploit_history.append(tradeoff > epsilon)
+    # input()
+    #import code
+    #code.interact(local=locals())
+
+# Helper functions
+
+def player_move(ttt):
+    ttt.display_board()
+    state_index_from_board(ttt.board)
+    ttt.display_board()
+    print("Human's turn")
+
+    flag = False
+    while flag is not True:
+        move = int(input("What is your move? [1-9] "))
+        move -= 1  # the range is 0-8 in the array
+        flag = ttt.input_is_valid(move)
+
+    ttt.update_board(move)
+    ttt.winner = ttt.check_win()
+    ttt.display_board()
+
+def ai_move(ttt, Q):
+    ttt.display_board()
+    print("AI's turn")
+
+    move = exploitation(ttt, Q)
+
+    ttt.input_is_valid(move)
+    ttt.update_board(move)
+    ttt.winner = ttt.check_win()
+    ttt.display_board()
+
+
+# plot graphs
+cumulative_win = np.cumsum(reward_history)
+plt.plot(cumulative_win)
+plt.title("Cumulative reward")
+plt.xlabel("episodes")
+plt.ylabel("cumulative sum of rewards")
+plt.show()
+
+plt.plot(epsilon_history)
+plt.title("Epsilon history")
+plt.show()
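The helper functions `player_move` and `ai_move` defined at the end of `q_learning.py` are never wired together in this commit. A minimal interactive loop using them could look like the sketch below; the function name `play_against_ai` and the choice of letting the human play first are assumptions, not part of the repository.

```python
def play_against_ai(Q):
    # Let a human (player 1) play against the greedy policy stored in Q,
    # reusing the player_move / ai_move helpers from q_learning.py.
    # Caveat: ai_move does not fall back to a random move when the greedy
    # action targets an occupied cell, so the loop can stall in that case.
    ttt = TicTacToe()
    while not ttt.done:
        if ttt.player == 1:
            player_move(ttt)   # human enters a move from 1 to 9
        else:
            ai_move(ttt, Q)    # greedy move from the Q-table
    if ttt.winner == 0:
        print("Draw")
    else:
        print(f"Player {ttt.winner} wins")

#play_against_ai(Q)  # would use the Q-table from the training run above
```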