You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
4.6 KiB
Python
138 lines
4.6 KiB
Python
import datetime
|
|
import os
|
|
import socket
|
|
import threading
|
|
import time
|
|
import queue
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from tf_agents.agents.reinforce import reinforce_agent
|
|
from tf_agents.environments import py_environment, tf_py_environment
|
|
from tf_agents.networks import actor_distribution_network
|
|
from tf_agents.policies import policy_saver
|
|
from tf_agents.replay_buffers import tf_uniform_replay_buffer
|
|
from tf_agents.trajectories import trajectory
|
|
from tf_agents.utils import common
|
|
|
|
from coapthon.client.helperclient import HelperClient
|
|
from coapthon.client.superviseur import (SuperviseurGlobal,
|
|
SuperviseurLocalFiltre)
|
|
from coapthon.utils import parse_uri
|
|
from utils_learning import MaquetteCoapEnv
|
|
|
|
# Hyperparameters.
fc_layer_params = (30,)  # fc_layer_params handed to the actor network
replay_buffer_capacity = 1500  # max_length of the replay buffer
learning_rate = 0.05  # learning rate handed to the Adam optimizer

# Size of the setup.
n_capteur = 25  # CoAP clients created per environment; also the action length
n_superviseur = 6  # number of collection environments (one thread each)

# Per-run output directory, named after the start date, hour and minute.
# The clock is read once so the three fields always describe the same instant
# (separate now() calls could straddle a minute, hour or day boundary).
_run_start = datetime.datetime.now()
tempdir = "save_run_{}-{}-{}".format(
    _run_start.date(), _run_start.hour, _run_start.minute)
# Split the CoAP server URI into host, port and resource path.
host, port, path = parse_uri("coap://raspberrypi.local/basic")

# Swap the host name for its resolved address when name resolution succeeds;
# on a resolution error the name returned by parse_uri is kept unchanged.
try:
    tmp = socket.gethostbyname(host)
except socket.gaierror:
    pass
else:
    host = tmp
def _build_maquette():
    """Create one MaquetteCoapEnv backed by n_capteur fresh CoAP clients."""
    clients = [HelperClient(server=(host, port)) for _ in range(n_capteur)]
    return MaquetteCoapEnv(
        clients, SuperviseurLocalFiltre, SuperviseurGlobal, path)


# One evaluation environment plus n_superviseur collection environments,
# each with its own set of clients.
eval_env_py = _build_maquette()
maquettes_py = [_build_maquette() for _ in range(n_superviseur)]

# TensorFlow wrappers around the Python environments.
maquettes = [tf_py_environment.TFPyEnvironment(env) for env in maquettes_py]
eval_env = tf_py_environment.TFPyEnvironment(eval_env_py)
# Actor network built from the specs of the evaluation environment.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    input_tensor_spec=eval_env.observation_spec(),
    output_tensor_spec=eval_env.action_spec(),
    fc_layer_params=fc_layer_params,
)

# The TF1-compat Adam optimizer is the one in use; the Keras variant
# (tf.keras.optimizers.Adam(learning_rate=learning_rate)) is not enabled.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Two distinct counters: global_step is handed to the checkpointer, while
# train_step_counter is the one given to the agent.
global_step = tf.compat.v1.train.get_or_create_global_step()
train_step_counter = tf.compat.v2.Variable(0)

tf_agent = reinforce_agent.ReinforceAgent(
    time_step_spec=eval_env.time_step_spec(),
    action_spec=eval_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter,
)
tf_agent.initialize()

collect_policy = tf_agent.collect_policy  # With exploration
eval_policy = tf_agent.policy  # Without exploration

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=tf_agent.collect_data_spec,
    batch_size=eval_env.batch_size,
    max_length=replay_buffer_capacity,
)

# Serializes access to replay_buffer between the collector threads and the
# training loop.
buffer_lock = threading.Lock()
def collecteur(maquette, policy):
    """Feed transitions from one environment into the shared replay buffer.

    Never returns: it loops forever and is meant to be the target of a
    collector thread. On every iteration it asks ``policy`` for an action,
    steps ``maquette``, and appends the resulting trajectory to the
    module-level ``replay_buffer`` while holding ``buffer_lock``.

    Args:
        maquette: environment exposing ``step(action)`` and ``reset()``, both
            returning a time step.
        policy: policy exposing ``action(time_step)``; the returned object
            must carry the chosen action in its ``action`` attribute.
    """
    # Initial step with an all-zero action of length n_capteur, used only to
    # obtain the first time step.
    time_step = maquette.step(np.zeros(n_capteur, dtype=np.float32))
    while True:
        action_step = policy.action(time_step)
        next_time_step = maquette.step(action_step.action)
        traj = trajectory.from_transition(
            time_step, action_step, next_time_step)

        # The buffer is shared with the other collectors and with the
        # training loop, so every write goes through the lock.
        with buffer_lock:
            replay_buffer.add_batch(traj)

        if traj.is_boundary():
            # Continue from the time step produced by the reset. Previously
            # the return value was dropped and the next action was computed
            # from a time step that no longer matched the environment.
            time_step = maquette.reset()
        else:
            time_step = next_time_step
# Checkpoints and the exported policy both live under the per-run directory.
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
policy_dir = os.path.join(tempdir, 'policy')

# Objects tracked by the training checkpoint.
_checkpointed_objects = {
    'agent': tf_agent,
    'policy': tf_agent.policy,
    'replay_buffer': replay_buffer,
    'global_step': global_step,
}
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir, max_to_keep=1, **_checkpointed_objects)

# Saver used to export the greedy policy into policy_dir.
tf_policy_saver = policy_saver.PolicySaver(tf_agent.policy)
# One collector thread per collection environment, all sharing the
# exploration policy.
threads_collecteur = [
    threading.Thread(
        target=collecteur,
        name="Collecteur {}".format(n),
        args=(maquette, collect_policy),
    )
    for n, maquette in enumerate(maquettes)
]

# Plain loop: starting a thread is a side effect, so no throwaway list of
# None values is built.
for thread in threads_collecteur:
    thread.start()
# Training loop: wake up every 60 seconds, train on everything collected so
# far once at least 200 frames are available, then report and checkpoint.
while True:
    time.sleep(60)

    # NOTE(review): the original indentation of this loop was lost. The
    # status print and the checkpoint save are kept under the lock because
    # both read replay_buffer, which the collector threads write to. Confirm
    # this matches the intended nesting.
    with buffer_lock:
        if replay_buffer.num_frames() >= 200:
            # gather_all() was previously called twice in a row; one call is
            # enough, the first result was simply overwritten.
            experience = replay_buffer.gather_all()
            tf_agent.train(experience)
            replay_buffer.clear()
            tf_policy_saver.save(policy_dir)

        print("thread : {}\tBuffer :{}".format(
            threading.active_count(), replay_buffer.num_frames()))

        # NOTE(review): global_step is never incremented in this file (the
        # agent counts steps in train_step_counter), so the checkpoint number
        # presumably stays constant. Verify whether that is intended.
        try:
            train_checkpointer.save(global_step)
        except Exception as exc:
            # Checkpointing stays best-effort so training keeps running, but
            # the failure is reported instead of being silently dropped.
            print("Checkpoint save failed: {!r}".format(exc))