You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
CoAP/boucle.py

138 lines
4.6 KiB
Python

import datetime
import os
import socket
import threading
import time
import queue
import numpy as np
import tensorflow as tf
from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.environments import py_environment, tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import policy_saver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
from coapthon.client.helperclient import HelperClient
from coapthon.client.superviseur import (SuperviseurGlobal,
SuperviseurLocalFiltre)
from coapthon.utils import parse_uri
from utils_learning import MaquetteCoapEnv
# Hyper-parameters of the policy network and of the training run.
fc_layer_params = (30,)  # fully connected layer sizes of the actor network
replay_buffer_capacity = 1500  # max_length handed to the replay buffer
learning_rate = 0.05  # step size of the Adam optimizer
n_capteur = 25  # CoAP clients created per environment; also the action length
n_superviseur = 6  # number of environments, each driven by its own thread

# Output directory of this run: "save_run_<date>-<hour>-<minute>".
# The clock is read exactly once so that the date, hour and minute all belong
# to the same instant; separate now() calls could straddle a minute, hour or
# day boundary and yield an inconsistent directory name.
_run_start = datetime.datetime.now()
tempdir = "save_run_{}-{}-{}".format(
    _run_start.date(), _run_start.hour, _run_start.minute)
# CoAP endpoint shared by every client. The host name is resolved once up
# front; when resolution fails the unresolved name is kept and used as-is.
host, port, path = parse_uri("coap://raspberrypi.local/basic")
try:
    host = socket.gethostbyname(host)
except socket.gaierror:
    pass


def _nouvelle_maquette():
    """Build one MaquetteCoapEnv backed by n_capteur new CoAP clients."""
    clients = [HelperClient(server=(host, port)) for _ in range(n_capteur)]
    return MaquetteCoapEnv(
        clients, SuperviseurLocalFiltre, SuperviseurGlobal, path)


# Construction order: the evaluation environment first, then one environment
# per collector thread, and only afterwards the TensorFlow wrappers.
eval_env_py = _nouvelle_maquette()
maquettes_py = [_nouvelle_maquette() for _ in range(n_superviseur)]
maquettes = list(map(tf_py_environment.TFPyEnvironment, maquettes_py))
eval_env = tf_py_environment.TFPyEnvironment(eval_env_py)
# All specs are taken from the evaluation environment.
observation_spec = eval_env.observation_spec()
action_spec = eval_env.action_spec()
time_step_spec = eval_env.time_step_spec()

# Actor network producing the action distribution of the policy.
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec, action_spec, fc_layer_params=fc_layer_params)

# TF1-compat Adam is used; the Keras variant
# (tf.keras.optimizers.Adam(learning_rate=learning_rate)) is left disabled.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
train_step_counter = tf.compat.v2.Variable(0)

# REINFORCE agent built on the actor network, with return normalisation.
tf_agent = reinforce_agent.ReinforceAgent(
    time_step_spec,
    action_spec,
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()

collect_policy = tf_agent.collect_policy  # with exploration
eval_policy = tf_agent.policy  # without exploration

# Replay buffer written by the collector threads and read by the training
# loop.
# NOTE(review): every collector appends to this one buffer, whose batch size
# is that of a single environment, so frames coming from different
# environments end up interleaved -- confirm this is what REINFORCE's
# per-episode return computation expects.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=tf_agent.collect_data_spec,
    batch_size=eval_env.batch_size,
    max_length=replay_buffer_capacity)

# Serialises every access to replay_buffer across threads.
buffer_lock = threading.Lock()
def collecteur(maquette, policy):
    """Collect transitions from one environment forever.

    Repeatedly asks ``policy`` for an action, applies it to ``maquette`` and
    appends the resulting transition to the module-level ``replay_buffer``
    while holding ``buffer_lock``. Intended to be used as a thread target.

    Args:
        maquette: environment offering ``step(action)`` and ``reset()``,
            both returning a time step.
        policy: policy offering ``action(time_step)`` whose result carries
            the chosen action in its ``action`` attribute.

    This function never returns.
    """
    # An all-zero action is sent first to obtain an initial time step.
    time_step = maquette.step(np.array(n_capteur * [0], dtype=np.float32))
    while True:
        action_step = policy.action(time_step)
        next_time_step = maquette.step(action_step.action)
        traj = trajectory.from_transition(
            time_step, action_step, next_time_step)
        # The buffer is shared with the other collectors and the trainer.
        with buffer_lock:
            replay_buffer.add_batch(traj)
        if traj.is_boundary():
            # Continue from the time step produced by the reset. Carrying on
            # with next_time_step would make the policy act on a time step
            # that predates the reset.
            time_step = maquette.reset()
        else:
            time_step = next_time_step
# Checkpointer for the training state (agent, greedy policy, replay buffer
# and global step); only the most recent checkpoint is kept.
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=tf_agent,
    policy=tf_agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

# Saver used to export the greedy policy into policy_dir.
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(tf_agent.policy)

# One collector thread per environment, all acting with the exploring policy.
threads_collecteur = [threading.Thread(target=collecteur,
                                       name="Collecteur {}".format(n),
                                       args=(maquette, collect_policy))
                      for n, maquette in enumerate(maquettes)]
# start() is called purely for its side effect, so a plain loop is used
# instead of building a throwaway list of None values.
for thread in threads_collecteur:
    thread.start()
# Training loop: once a minute, train on everything collected so far as soon
# as at least 200 frames are buffered, export the policy, then checkpoint.
while True:
    time.sleep(60)
    # Holding the lock keeps the collectors from writing while the buffer is
    # read, trained on and cleared.
    with buffer_lock:
        if replay_buffer.num_frames() >= 200:
            # A single gather is enough; the buffer is emptied right after
            # training so the next round only sees new frames.
            experience = replay_buffer.gather_all()
            train_loss = tf_agent.train(experience)
            replay_buffer.clear()
            tf_policy_saver.save(policy_dir)
        print("thread : {}\tBuffer :{}".format(
            threading.active_count(), replay_buffer.num_frames()))
    try:
        train_checkpointer.save(global_step)
    except Exception as exc:
        # Checkpointing stays best effort (training must go on), but the
        # failure is reported instead of being silently discarded.
        print("Checkpoint save failed: {!r}".format(exc))