from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint, uniform
from hashlib import sha1
import pickle
import matplotlib.pyplot as plt

# Retrieve the board_set from the pickled data
with open("pickle_board_set.pkl", "rb") as fd:
    board_set = pickle.load(fd)

# Make it a list so that each state hash has an index
board_list = list(board_set)


def state_index_from_board(board):
    """
    Return the Q-table index of a given board.

    For any board represented by a 3x3 ndarray, return the index in the Q-table
    of the corresponding state, taking the 8 symmetries of the square into account.

    There are 627 different non-ending states in the Q-table.
    """

    # Compute all the symmetries
    b1 = board
    b2 = np.rot90(b1, 1)
    b3 = np.rot90(b1, 2)
    b4 = np.rot90(b1, 3)
    b5 = np.fliplr(b1)
    b6 = np.flipud(b1)
    b7 = b1.T           # mirror diagonally
    b8 = np.fliplr(b2)  # mirror anti-diagonally

    # Hash each symmetry and look it up in the list of canonical states
    for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
        hd = sha1(np.ascontiguousarray(b)).hexdigest()
        if hd in board_list:
            state_index = board_list.index(hd)
            ttt.board = b  # rotate the (global) board to its canonical orientation
            break

    return state_index


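# Illustrative sketch (not executed, and not part of the original script): boards
# related by one of the 8 symmetries map to the same Q-table row. Note that the
# function also rewrites the global ttt.board, so a game must be bound to ttt first.
#
#   ttt = TicTacToe()
#   i1 = state_index_from_board(ttt.board)
#   i2 = state_index_from_board(np.rot90(ttt.board))
#   assert i1 == i2

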
def exploitation(ttt, Q):
    """
    Exploit the Q-table.

    Fully exploit the knowledge from the Q-table to retrieve the next action to
    take in the given state.

    Parameters
    ==========
    ttt: a game, i.e. an instance of the TicTacToe class
    Q: the Q-table
    """

    state_index = state_index_from_board(ttt.board)

    # Player 1 maximises the Q-values, player 2 minimises them (rewards are
    # positive when player 1 wins and negative when player 2 wins)
    if ttt.player == 1:
        a = np.argmax(Q[state_index, :])
    else:
        a = np.argmin(Q[state_index, :])

    return a


def random_agent(ttt):
    """Return a random valid move (0-8) for the current board."""
    flag = False
    while not flag:
        move = randint(1, 9) - 1
        flag = ttt.input_is_valid(move)
    return move


# Initialisation of the Q-table
state_size = 627   # hardcoded, calculated from experiment (627 non-ending states up to symmetry)
action_size = 9    # the 9 possible moves at each turn
Q = np.zeros((state_size, action_size))

# Hyperparameters
num_episodes = 10000
max_steps = 9   # max number of steps per episode
alpha = 1       # learning rate (the problem is deterministic)

# Exploration / exploitation parameters
epsilon = 1
max_epsilon = 1
min_epsilon = 0.01
decay_rate = 0.001

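# For reference, this schedule (updated at the end of each episode below) gives
# epsilon ≈ 1.0 at episode 0, ≈ 0.37 around episode 1000 and ≈ 0.01 by episode 10000.
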
winner_history = []
reward_history = []
exploit_history = []
epsilon_history = []

for k in tqdm(range(num_episodes)):

    list_state_move = []

    # Reset the board
    ttt = TicTacToe()

    while not ttt.done:
        # ttt.display_board()
        # Get the state index and rotate the board to its canonical orientation
        state_index = state_index_from_board(ttt.board)
        # print(f"state_index = {state_index}")
        # print(Q[state_index])

        # Explore or exploit?
        tradeoff = uniform(0, 1)

        do_exploit = tradeoff > epsilon
        # do_exploit = False  # debug override: uncomment to never exploit

        if do_exploit:
            # Exploit the Q-table, falling back to a random move if the greedy
            # move is not valid
            move = exploitation(ttt, Q)
            if not ttt.input_is_valid(move):
                move = random_agent(ttt)
        else:
            # Random agent (exploration)
            move = random_agent(ttt)

        # Remember the couples (s, a) to give the reward at the end of the episode
        list_state_move.append((state_index, move))

        ttt.update_board(move)
        ttt.winner = ttt.check_win()

    if k % 1000 == 0:
        tqdm.write(f"episode: {k}, epsilon: {epsilon:0.3f}, "
                   f"winner: {ttt.winner}, exploited: {tradeoff > epsilon}")

    # Terminal reward: +1 if player 1 won, -1 if player 2 won, 0 for a draw
    if ttt.winner == 1:
        r = 1
    elif ttt.winner == 2:
        r = -1
    else:
        r = 0  # draw

    # print(r)
    reward_history.append(r)

    # Update the Q-table (not yet using the Bellman equation, the case is simple
    # enough; a sketch of that update follows the training loop)
    for s, a in list_state_move:
        Q[s, a] += alpha * r
        r *= -1  # negate the reward for the next move, since the players alternate

    # Update the (decreasing) epsilon-greedy strategy
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * k)

    # Remember stats for history
    winner_history.append(ttt.winner)
    epsilon_history.append(epsilon)
    exploit_history.append(tradeoff > epsilon)

    # input()
    # import code
    # code.interact(local=locals())

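# Sketch (an assumption, not used by the training loop above): the Bellman /
# Q-learning update that the loop comment refers to, for a state s, action a,
# reward r and successor state s_next, with a discount factor gamma.
#
#   gamma = 0.9
#   Q[s, a] += alpha * (r + gamma * np.max(Q[s_next, :]) - Q[s, a])
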
# Helper functions

def player_move(ttt):
    ttt.display_board()
    # Rotate the board to its canonical orientation (side effect of the lookup)
    state_index_from_board(ttt.board)
    ttt.display_board()
    print("Human's turn")

    flag = False
    while not flag:
        move = int(input("What is your move? [1-9] "))
        move -= 1  # the range is 0-8 in the array
        flag = ttt.input_is_valid(move)

    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


def ai_move(ttt, Q):
    ttt.display_board()
    print("AI's turn")

    move = exploitation(ttt, Q)

    # Fall back to a random move if the greedy move is not valid
    if not ttt.input_is_valid(move):
        move = random_agent(ttt)

    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


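# Example (sketch, kept commented out so the script still runs non-interactively):
# a game against the learned table using the helpers above. The human playing as
# player 1 and going first is an assumption.
#
#   ttt = TicTacToe()
#   while not ttt.done:
#       player_move(ttt)
#       if not ttt.done:
#           ai_move(ttt, Q)
#   print(f"winner: {ttt.winner}")

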
# Plot the training statistics
cumulative_reward = np.cumsum(reward_history)
plt.plot(cumulative_reward)
plt.title("Cumulative reward")
plt.xlabel("episodes")
plt.ylabel("cumulative sum of rewards")
plt.show()

plt.plot(epsilon_history)
plt.title("Epsilon history")
plt.xlabel("episodes")
plt.ylabel("epsilon")
plt.show()
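
# Optional sketch (not in the original script): winner_history and exploit_history
# are collected above but never plotted; a rolling win rate for player 1 can be
# shown like this. The window size is an arbitrary choice.
window = 500
p1_wins = np.array([w == 1 for w in winner_history], dtype=float)
rolling_win_rate = np.convolve(p1_wins, np.ones(window) / window, mode="valid")
plt.plot(rolling_win_rate)
plt.title(f"Rolling share of player 1 wins (window = {window})")
plt.xlabel("episodes")
plt.ylabel("win rate")
plt.show()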