tic-tac-toe-rl/q_learning.py

from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint, uniform
from hashlib import sha1
import pickle
import matplotlib.pyplot as plt

# Retrieve the board_set from the pickled data
with open("pickle_board_set.pkl", "rb") as fd:
    board_set = pickle.load(fd)

# Make it a list so that every state gets a stable index
board_list = list(board_set)

def state_index_from_board(board):
    """
    Return the Q-table index of a given board.

    For any board represented by a 3x3 ndarray, return the index in the
    Q-table of the corresponding state. Takes the 8 symmetries into account.
    There are 627 different non-ending states in the Q-table.
    """
    # Compute all the symmetries
    b1 = board
    b2 = np.rot90(b1, 1)
    b3 = np.rot90(b1, 2)
    b4 = np.rot90(b1, 3)
    b5 = np.fliplr(b1)
    b6 = np.flipud(b1)
    b7 = b1.T           # mirror along the main diagonal
    b8 = np.fliplr(b2)  # mirror along the anti-diagonal
    # Hash each symmetry and look it up in the canonical list
    for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
        hd = sha1(np.ascontiguousarray(b)).hexdigest()
        if hd in board_list:
            state_index = board_list.index(hd)
            # Rewrite the module-level game `ttt` so that the board is stored
            # in its canonical orientation (the one the Q-table was built on).
            ttt.board = b
            return state_index
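
# Small sanity check (a sketch; it relies on the fact that
# state_index_from_board canonicalises the module-level game `ttt`):
# a corner opening and its 90-degree rotation should map to the same index.
ttt = TicTacToe()
ttt.input_is_valid(0)
ttt.update_board(0)  # first player opens in the top-left corner
idx_corner = state_index_from_board(ttt.board)
idx_rotated = state_index_from_board(np.rot90(ttt.board))
assert idx_corner == idx_rotated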

def exploitation(ttt, Q):
    """
    Exploit the Q-table.

    Fully exploit the knowledge from the Q-table to retrieve the next action
    to take in the given state.

    Parameters
    ==========
    ttt: a game, i.e. an instance of the TicTacToe class
    Q: the Q-table
    """
    state_index = state_index_from_board(ttt.board)
    if ttt.player == 1:
        # player 1 maximises the value, player 2 minimises it
        a = np.argmax(Q[state_index, :])
    else:
        a = np.argmin(Q[state_index, :])
    return a
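
# A possible refinement (a sketch, not used below): mask occupied cells before
# taking argmax/argmin so that exploitation never proposes an illegal move.
# It assumes empty cells are encoded as 0 in ttt.board, which is an assumption
# about game.py, not something verified here.
def masked_exploitation(ttt, Q):
    state_index = state_index_from_board(ttt.board)
    q_values = Q[state_index, :].astype(float)  # copy, so Q is not modified
    free = (ttt.board.flatten() == 0)           # assumed encoding of empty cells
    if ttt.player == 1:
        q_values[~free] = -np.inf               # player 1 maximises
        return int(np.argmax(q_values))
    q_values[~free] = np.inf                    # player 2 minimises
    return int(np.argmin(q_values))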

def random_agent(ttt):
    """Return a uniformly random valid move for the current board."""
    flag = False
    while flag is not True:
        move = randint(1, 9) - 1
        flag = ttt.input_is_valid(move)
    return move

# Initialisation of the Q-table
state_size = 627   # hardcoded, obtained experimentally (number of non-ending states)
action_size = 9    # the 9 possible moves at each turn
Q = np.zeros((state_size, action_size))

# Hyperparameters
num_episodes = 10000
max_steps = 9      # max number of steps per episode
alpha = 1          # learning rate (the problem is deterministic)

# Exploration / exploitation parameters
epsilon = 1
max_epsilon = 1
min_epsilon = 0.01
decay_rate = 0.001
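
# Shape of the decay schedule used below,
#   epsilon(k) = min_epsilon + (max_epsilon - min_epsilon) * exp(-decay_rate * k)
# (rounded values, for intuition only):
#   k = 0      -> 1.000
#   k = 1000   -> ~0.374
#   k = 5000   -> ~0.017
#   k = 10000  -> ~0.010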

# Training history, kept for the plots below
winner_history = []
reward_history = []
exploit_history = []
epsilon_history = []

for k in tqdm(range(num_episodes)):
    list_state_move = []
    # Reset the board
    ttt = TicTacToe()
    while not ttt.done:
        # Get the state index and rotate the board to its canonical orientation
        state_index = state_index_from_board(ttt.board)
        # Explore or exploit?
        tradeoff = uniform(0, 1)
        do_exploit = tradeoff > epsilon
        # Debug switch: force pure exploration for the whole training run
        do_exploit = False
        if do_exploit:
            # Exploit the Q-table
            move = exploitation(ttt, Q)
            if not ttt.input_is_valid(move):
                # Fall back to a random move if the greedy action is illegal
                move = random_agent(ttt)
        else:
            # Random agent (exploration)
            move = random_agent(ttt)
        # Remember the pairs (s, a) to reward them at the end of the episode
        list_state_move.append((state_index, move))
        ttt.update_board(move)
        ttt.winner = ttt.check_win()
    if k % 1000 == 0:
        tqdm.write(f"episode: {k}, epsilon: {epsilon:0.3f}, winner: {ttt.winner}, "
                   f"exploited: {tradeoff > epsilon}")
    # Reward shaping: +1 if player 1 wins, -1 if player 2 wins, 0 for a draw
    if ttt.winner == 1:
        r = 1
    elif ttt.winner == 2:
        r = -1
    else:
        r = 0  # draw
    reward_history.append(r)
    # Update the Q-table (not yet using the Bellman equation, the case is too
    # simple): every (state, action) pair of the episode receives the final
    # reward, with alternating sign because the players alternate.
    for s, a in list_state_move:
        Q[s, a] += alpha * r
        r *= -1  # invert the reward for the next move due to player inversion
    # Update the (decreasing) epsilon-greedy strategy
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * k)
    # Remember stats for history
    winner_history.append(ttt.winner)
    epsilon_history.append(epsilon)
    exploit_history.append(tradeoff > epsilon)
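
# For reference, the standard single-agent tabular Q-learning update (a sketch,
# not what the loop above does, which simply spreads the final reward over the
# episode). `gamma` and the per-step transition (s, a, r, s_next) are
# illustrative assumptions; in this two-player setting the bootstrap term would
# also have to alternate between max and min depending on whose turn s_next is.
def q_learning_step(Q, s, a, r, s_next, alpha=0.1, gamma=0.9):
    """Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s_next, a') - Q(s,a))"""
    td_target = r + gamma * np.max(Q[s_next, :])
    Q[s, a] += alpha * (td_target - Q[s, a])
    return Q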

# Helper functions to play against the trained agent
def player_move(ttt):
    ttt.display_board()
    # Canonicalise the board so that the Q-table indices stay consistent,
    # then show the (possibly rotated) board the human actually plays on
    state_index_from_board(ttt.board)
    ttt.display_board()
    print("Human's turn")
    flag = False
    while flag is not True:
        move = int(input("What is your move? [1-9] "))
        move -= 1  # the array range is 0-8
        flag = ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()

def ai_move(ttt, Q):
    ttt.display_board()
    print("AI's turn")
    move = exploitation(ttt, Q)
    ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()
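
# A minimal interactive loop built on the helpers above (a sketch; it assumes
# the human plays first as player 1 and uses only the TicTacToe attributes
# already used in this file: done, player, winner). The `global ttt` is needed
# because state_index_from_board canonicalises the module-level game.
def play_against_ai(Q):
    global ttt
    ttt = TicTacToe()
    while not ttt.done:
        if ttt.player == 1:
            player_move(ttt)
        else:
            ai_move(ttt, Q)
    ttt.display_board()
    print(f"Winner: {ttt.winner}")

# To play a game against the learned policy: play_against_ai(Q)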

# Plot the training curves
cumulative_win = np.cumsum(reward_history)
plt.plot(cumulative_win)
plt.title("Cumulative reward")
plt.xlabel("episodes")
plt.ylabel("cumulative sum of rewards")
plt.show()

plt.plot(epsilon_history)
plt.title("Epsilon history")
plt.xlabel("episodes")
plt.show()
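
# Persist the learned Q-table so it can be reloaded without retraining
# (a sketch; the file name "q_table.npy" is an arbitrary choice).
np.save("q_table.npy", Q)
# Later: Q = np.load("q_table.npy")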