tic-tac-toe-rl/q_learning.py

from game import TicTacToe
from tqdm import tqdm
import numpy as np
from random import randint, uniform
from hashlib import sha1
import pickle
import matplotlib.pyplot as plt

# Retrieve the board_set from the pickled data
fd = open("pickle_board_set.pkl", "rb")
board_set = pickle.load(fd)
fd.close()
# Make it a list
board_list = list(board_set)


def state_index_from_board(board):
    """
    Return the Q-table index of a given board.

    For any board represented by a 3x3 ndarray, return the index in the Q-table
    of the corresponding state. Takes into account the 8 symmetries.
    There are 627 different non-ending states in the Q-table.
    """
    # Compute all the symmetries
    b1 = board
    b2 = np.rot90(b1, 1)
    b3 = np.rot90(b1, 2)
    b4 = np.rot90(b1, 3)
    b5 = np.fliplr(b1)
    b6 = np.flipud(b1)
    b7 = b1.T           # mirror diagonally
    b8 = np.fliplr(b2)  # mirror anti-diagonally
    # Compute the hash of each symmetry and look it up in the canonical list
    for b in [b1, b2, b3, b4, b5, b6, b7, b8]:
        hd = sha1(np.ascontiguousarray(b)).hexdigest()
        if hd in board_list:
            state_index = board_list.index(hd)
            ttt.board = b  # rotate the global board to its canonical value
            return state_index
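

# Illustration (sketch, assuming pickle_board_set.pkl stores one canonical hash
# per symmetry class): a board `b` and its left-right mirror np.fliplr(b)
# generate the same set of 8 symmetries, so both resolve to the same Q-table
# row:
#   state_index_from_board(b) == state_index_from_board(np.fliplr(b))  # True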


def exploitation(ttt, Q):
    """
    Exploit the Q-table.

    Fully exploit the knowledge from the Q-table to retrieve the next action to
    take in the given state.

    Parameters
    ==========
    ttt: a game, i.e. an instance of the TicTacToe class
    Q: the Q-table
    """
    state_index = state_index_from_board(ttt.board)
    if ttt.player == 1:
        # player 1 maximises the value, player 2 minimises it
        a = np.argmax(Q[state_index, :])
    else:
        a = np.argmin(Q[state_index, :])
    return a
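

# Note: with Q initialised to zeros, np.argmax / np.argmin simply return the
# first of the tied actions, which may point to an occupied cell; this is why
# the training loop below falls back to random_agent() when the greedy move is
# invalid.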


def random_agent(ttt):
    """Return a random valid move (0-8) for the current board."""
    flag = False
    while flag is not True:
        move = randint(1, 9) - 1
        flag = ttt.input_is_valid(move)
    return move


# initialisation of the Q-table
state_size = 627  # hardcoded, calculated from experiment
action_size = 9   # the 9 possible moves at each turn
Q = np.zeros((state_size, action_size))

# hyperparameters
num_episodes = 10000
max_test = 9  # max number of steps per episode
alpha = 1     # learning rate (the problem is deterministic)

# exploration / exploitation parameters
epsilon = 1
max_epsilon = 1
min_epsilon = 0.01
decay_rate = 0.001

# history containers for the training statistics
winner_history = []
reward_history = []
exploit_history = []
epsilon_history = []

for k in tqdm(range(num_episodes)):
    list_state_move = []
    # reset the board
    ttt = TicTacToe()

    while not ttt.done:
        # ttt.display_board()
        # get the state index and rotate the board to its canonical form
        state_index = state_index_from_board(ttt.board)
        # print(f"state_index = {state_index}")
        # print(Q[state_index])

        # explore or exploit?
        tradeoff = uniform(0, 1)
        do_exploit = tradeoff > epsilon
        # debug: never exploit (both players play randomly during training)
        do_exploit = False
        if do_exploit:
            # exploit
            move = exploitation(ttt, Q)
            # fall back to a random move if the greedy action is invalid
            if not ttt.input_is_valid(move):
                move = random_agent(ttt)
        else:
            # Random Agent (exploration)
            move = random_agent(ttt)

        # remember the (state, action) pairs to reward at the end of the episode
        list_state_move.append((state_index, move))
        ttt.update_board(move)
        ttt.winner = ttt.check_win()

    if k % 1000 == 0:
        tqdm.write(f"episode: {k}, epsilon: {epsilon:0.3f}, "
                   f"winner: {ttt.winner}, exploited: {tradeoff > epsilon}")

    # reward shaping
    if ttt.winner == 1:
        r = 1
    elif ttt.winner == 2:
        r = -1
    else:
        r = 0  # draw
    # print(r)
    reward_history.append(r)

    # Update the Q-table (not yet using the Bellman equation, the case is too simple)
    for s, a in list_state_move:
        Q[s, a] += alpha * r
        r *= -1  # invert the reward for the next move due to player alternation
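
    # For reference, a standard temporal-difference (Bellman) update would look
    # like the sketch below, where `gamma` and `s_next` are illustrative names,
    # not variables defined in this script:
    #   Q[s, a] += alpha * (r + gamma * np.max(Q[s_next, :]) - Q[s, a])
    # Here the game is short and fully deterministic, so the final reward is
    # simply propagated (with alternating sign) to every visited pair.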

    # Update the epsilon-greedy (decreasing) strategy
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * k)
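    # With decay_rate = 0.001 this schedule gives roughly epsilon ~ 1.0 at
    # k = 0, ~ 0.37 at k = 1000, ~ 0.06 at k = 3000 and ~ 0.01 by the end of
    # the 10000 episodes.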

    # remember stats for history
    winner_history.append(ttt.winner)
    epsilon_history.append(epsilon)
    exploit_history.append(tradeoff > epsilon)
    # input()
    # import code
    # code.interact(local=locals())


# Helper functions
def player_move(ttt):
    ttt.display_board()
    # canonicalise the board (state_index_from_board rotates the global ttt.board)
    state_index_from_board(ttt.board)
    ttt.display_board()
    print("Human's turn")
    flag = False
    while flag is not True:
        move = int(input("What is your move? [1-9] "))
        move -= 1  # the array range is 0-8
        flag = ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()


def ai_move(ttt, Q):
    ttt.display_board()
    print("AI's turn")
    move = exploitation(ttt, Q)
    ttt.input_is_valid(move)
    ttt.update_board(move)
    ttt.winner = ttt.check_win()
    ttt.display_board()
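

# A possible way to wire the two helpers above into an interactive game -- a
# sketch, not called anywhere in this script. It assumes the human plays as
# player 1 (the first player) and that `ttt.done` / `ttt.player` behave as in
# the training loop above.
def play_against_ai(Q):
    global ttt  # state_index_from_board() canonicalises the global game board
    ttt = TicTacToe()
    while not ttt.done:
        if ttt.player == 1:
            player_move(ttt)
        else:
            ai_move(ttt, Q)
    print(f"Winner: {ttt.winner}")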


# plot graphs
cumulative_win = np.cumsum(reward_history)
plt.plot(cumulative_win)
plt.title("Cumulative reward")
plt.xlabel("episodes")
plt.ylabel("cumulative sum of rewards")
plt.show()

plt.plot(epsilon_history)
plt.title("Epsilon history")
plt.show()
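
# One way to persist the learned Q-table for later play (a sketch; the file
# name "q_table.npy" is an arbitrary choice, nothing else in this project
# expects it):
#   np.save("q_table.npy", Q)
#   Q = np.load("q_table.npy")  # reload it before calling ai_move() or play_against_ai()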