from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint, uniform
from hashlib import sha1
import pickle
import matplotlib.pyplot as plt

# Retrieve the board_set from the pickled data
with open("pickle_board_set.pkl", "rb") as fd:
    board_set = pickle.load(fd)

# Make it a list so that every state gets a stable index
board_list = list(board_set)


def state_index_from_board(board):
    """
    Return the Q-table index of a given board.

    For any board represented by a 3x3 ndarray, return the index in the
    Q-table of the corresponding state, taking into account the 8 symmetries
    of the square (rotations and reflections). There are 627 different
    non-ending states in the Q-table.

    Side effect: the board of the module-level game `ttt` is rotated to its
    canonical orientation.
    """
    # Compute all the symmetries
    b1 = board
    b2 = np.rot90(b1, 1)
    b3 = np.rot90(b1, 2)
    b4 = np.rot90(b1, 3)
    b5 = np.fliplr(b1)
    b6 = np.flipud(b1)
    b7 = b1.T           # mirror diagonally
    b8 = np.fliplr(b2)  # mirror anti-diagonally

    # Hash each symmetry and look it up in the canonical list
    for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
        hd = sha1(np.ascontiguousarray(b)).hexdigest()
        if hd in board_list:
            state_index = board_list.index(hd)
            ttt.board = b  # rotate the board to the canonical orientation
    return state_index


def exploitation(ttt, Q):
    """
    Exploit the Q-table.

    Fully exploit the knowledge from the Q-table to retrieve the next action
    to take in the given state.

    Parameters
    ==========
    ttt: a game, i.e. an instance of the TicTacToe class
    Q: the Q-table
    """
    state_index = state_index_from_board(ttt.board)
    if ttt.player == 1:
        a = np.argmax(Q[state_index, :])  # player 1 maximises the value
    else:
        a = np.argmin(Q[state_index, :])  # player 2 minimises it
    return a


def random_agent(ttt):
    """Return a random valid move (exploration)."""
    flag = False
    while not flag:
        move = randint(1, 9) - 1
        flag = ttt.input_is_valid(move)
    return move


# Initialisation of the Q-table
state_size = 627  # hardcoded, computed beforehand (non-ending states up to symmetry)
action_size = 9   # the 9 possible moves at each turn
Q = np.zeros((state_size, action_size))

# Hyperparameters
num_episodes = 10000
max_steps = 9  # max number of steps per episode (unused, a game never exceeds 9 moves)
alpha = 1      # learning rate (the problem is deterministic)

# Exploration / exploitation parameters
epsilon = 1
max_epsilon = 1
min_epsilon = 0.01
decay_rate = 0.001

winner_history = []
reward_history = []
exploit_history = []
epsilon_history = []

for k in tqdm(range(num_episodes)):
    list_state_move = []
    # Reset the board
    ttt = TicTacToe()

    while not ttt.done:
        # ttt.display_board()
        # Get the state index and rotate the board to its canonical orientation
        state_index = state_index_from_board(ttt.board)
        # print(f"state_index = {state_index}")
        # print(Q[state_index])

        # Explore or exploit?
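        # Epsilon-greedy trade-off: with probability epsilon the agent explores
        # (random move), otherwise it exploits the Q-table. Epsilon decays
        # exponentially with the episode number (see the update at the end of
        # the episode), so exploration dominates early and exploitation later.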
        tradeoff = uniform(0, 1)
        do_exploit = tradeoff > epsilon
        # Debug override: force pure exploration for now
        do_exploit = False

        if do_exploit:
            # Exploit the Q-table; fall back to a random move if the
            # suggested action is not valid
            move = exploitation(ttt, Q)
            if not ttt.input_is_valid(move):
                move = random_agent(ttt)
        else:
            # Random agent (exploration)
            move = random_agent(ttt)

        # Remember the couples (s, a) to reward them at the end of the episode
        list_state_move.append((state_index, move))

        ttt.update_board(move)
        ttt.winner = ttt.check_win()

    if k % 1000 == 0:
        tqdm.write(f"episode: {k}, epsilon: {epsilon:0.3f}, "
                   f"winner: {ttt.winner}, exploited: {tradeoff > epsilon}")

    # Reward shaping: +1 if player 1 won, -1 if player 2 won, 0 for a draw
    if ttt.winner == 1:
        r = 1
    elif ttt.winner == 2:
        r = -1
    else:
        r = 0  # draw
    # print(r)
    reward_history.append(r)

    # Update the Q-table (not yet using the Bellman equation, the case is too
    # simple): every (state, action) pair of the episode receives the final
    # reward, with alternating sign since the players alternate.
    for s, a in list_state_move:
        Q[s, a] += alpha * r
        r *= -1  # flip the reward for the next move (player inversion)

    # Update the (decreasing) epsilon-greedy strategy
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * k)

    # Remember stats for history
    winner_history.append(ttt.winner)
    epsilon_history.append(epsilon)
    exploit_history.append(tradeoff > epsilon)

    # input()
    # import code
    # code.interact(local=locals())


# Helper functions to play against the trained agent
def player_move(ttt):
    ttt.display_board()
    # Rotate the board to its canonical orientation before playing
    state_index_from_board(ttt.board)
    ttt.display_board()
    print("Human's turn")
    flag = False
    while not flag:
        move = int(input("What is your move? [1-9] "))
        move -= 1  # range is 0-8 in the array
        flag = ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


def ai_move(ttt, Q):
    ttt.display_board()
    print("AI's turn")
    move = exploitation(ttt, Q)
    ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


# Plot training statistics
cumulative_win = np.cumsum(reward_history)
plt.plot(cumulative_win)
plt.title("Cumulative reward")
plt.xlabel("episode")
plt.ylabel("cumulative sum of rewards")
plt.show()

plt.plot(epsilon_history)
plt.title("Epsilon history")
plt.show()
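
# --- Example: play a game against the trained agent -------------------------
# A minimal sketch of how the helper functions above can be wired together,
# assuming the human moves first as player 1 and the AI answers as player 2
# (adjust if your TicTacToe class starts the other way round). The name
# `play_against_ai` is introduced here for illustration only. The game is
# bound to the module-level `ttt` because state_index_from_board() rotates
# that global board to its canonical orientation.
def play_against_ai(Q):
    global ttt
    ttt = TicTacToe()
    while not ttt.done:
        player_move(ttt)
        if ttt.done:
            break
        ai_move(ttt, Q)
    print(f"Game over, winner: {ttt.winner}")


# Uncomment to play interactively once training has finished:
# play_against_ai(Q)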