Apprentissage (learning)
This commit is contained in: parent b235cdd151 · commit e329ee656b
8 changed files with 382 additions and 53 deletions
137  boucle.py  Normal file
@@ -0,0 +1,137 @@
import datetime
import os
import queue
import socket
import threading
import time

import numpy as np
import tensorflow as tf
from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.environments import py_environment, tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import policy_saver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

from coapthon.client.helperclient import HelperClient
from coapthon.client.superviseur import (SuperviseurGlobal,
                                         SuperviseurLocalFiltre)
from coapthon.utils import parse_uri
from utils_learning import MaquetteCoapEnv

# Hyperparameters of the REINFORCE agent and of the experiment.
fc_layer_params = (30,)
replay_buffer_capacity = 1500
learning_rate = 0.05

n_capteur = 25      # CoAP clients per environment
n_superviseur = 6   # parallel collection environments

tempdir = "save_run_{}-{}-{}".format(datetime.datetime.now().date(
), datetime.datetime.now().hour, datetime.datetime.now().minute)

host, port, path = parse_uri("coap://raspberrypi.local/basic")
try:
    tmp = socket.gethostbyname(host)
    host = tmp
except socket.gaierror:
    pass

eval_env_py = MaquetteCoapEnv([HelperClient(server=(host, port)) for _ in range(n_capteur)],
                              SuperviseurLocalFiltre,
                              SuperviseurGlobal,
                              path)
maquettes_py = [MaquetteCoapEnv([HelperClient(server=(host, port)) for _ in range(n_capteur)],
                                SuperviseurLocalFiltre,
                                SuperviseurGlobal,
                                path)
                for _ in range(n_superviseur)]

maquettes = [tf_py_environment.TFPyEnvironment(
    maquette) for maquette in maquettes_py]
eval_env = tf_py_environment.TFPyEnvironment(eval_env_py)

actor_net = actor_distribution_network.ActorDistributionNetwork(
    eval_env.observation_spec(),
    eval_env.action_spec(),
    fc_layer_params=fc_layer_params)

# optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()

train_step_counter = tf.compat.v2.Variable(0)

tf_agent = reinforce_agent.ReinforceAgent(
    eval_env.time_step_spec(),
    eval_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()

collect_policy = tf_agent.collect_policy  # with exploration
eval_policy = tf_agent.policy             # without exploration

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=tf_agent.collect_data_spec,
    batch_size=eval_env.batch_size,
    max_length=replay_buffer_capacity)

# The buffer is shared by all collector threads, so guard it with a lock.
buffer_lock = threading.Lock()


def collecteur(maquette, policy):
    """Step one environment forever and push its trajectories into the shared buffer."""
    # queue_commande = queue.Queue()
    time_step = maquette.step(np.array(n_capteur*[0], dtype=np.float32))
    while True:
        # if queue_commande.empty():
        #     pass
        # commande = queue.get()

        action_step = policy.action(time_step)
        next_time_step = maquette.step(action_step.action)
        traj = trajectory.from_transition(
            time_step, action_step, next_time_step)
        with buffer_lock:
            replay_buffer.add_batch(traj)
        time_step = next_time_step
        if traj.is_boundary():
            maquette.reset()


checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=tf_agent,
    policy=tf_agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(tf_agent.policy)

threads_collecteur = [threading.Thread(target=collecteur,
                                       name="Collecteur {}".format(n),
                                       args=(maquette, collect_policy))
                      for n, maquette in enumerate(maquettes)]

[thread.start() for thread in threads_collecteur]

# Training loop: once a minute, train on the collected trajectories when the
# buffer holds enough frames, then save the policy and the checkpoint.
while True:
    time.sleep(60)
    with buffer_lock:
        if replay_buffer.num_frames() >= 200:
            experience = replay_buffer.gather_all()
            train_loss = tf_agent.train(experience)
            replay_buffer.clear()
            tf_policy_saver.save(policy_dir)
    print("thread : {}\tBuffer :{}".format(
        threading.active_count(), replay_buffer.num_frames()))
    try:
        train_checkpointer.save(global_step)
    except Exception:
        pass
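Note (not part of the diff): the directory written by tf_policy_saver.save(policy_dir) is a regular TensorFlow SavedModel, so a separate evaluation run could reload it roughly as sketched below. The save_run_.../policy path is a hypothetical example; use whichever run directory boucle.py actually created, and rebuild eval_env the same way as above.

import tensorflow as tf

policy_dir = "save_run_2021-06-01-14-30/policy"  # hypothetical run directory
saved_policy = tf.saved_model.load(policy_dir)

# With eval_env built exactly as in boucle.py:
# time_step = eval_env.reset()
# while True:
#     action_step = saved_policy.action(time_step)
#     time_step = eval_env.step(action_step.action)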
@@ -233,6 +233,7 @@ class CoAP(object):
                 message.timeouted = False
             else:
                 logger.warning("Give up on message {message}".format(message=message.line_print))
+                self.superviseur.failed_request()
                 message.timeouted = True

                 # Inform the user, that nothing was received
@@ -28,6 +28,9 @@ class SuperviseurLocalPlaceHolder():
         self._taux_retransmition = 0
         self._RTO = defines.ACK_TIMEOUT

+    def reset_rto(self):
+        self._RTO = defines.ACK_TIMEOUT
+
     def envoie_message(self, message) -> None:
         self.envoie_token(message.token)
@@ -168,46 +171,18 @@ class SuperviseurGlobal():

     @property
     def state(self):
-        """[summary]
-
-        Returns:
-            [type]: [description]
-        """
-        """
-        taux_retransmissions = np.array(
-            [superviseur.taux_retransmission for superviseur in self.superviseurs])
-
-        min_rtts = np.array(
-            [superviseur.min_RTT for superviseur in self.superviseurs])
-        avg_rtts = np.array(
-            [superviseur.avg_RTT for superviseur in self.superviseurs])
-        ratio_rtts = np.array(min_rtts/avg_rtts)
-
-        if isinstance(self.superviseurs[0], SuperviseurLocalFiltre):
-            rtt_ls = np.array(
-                [superviseur.RTT_L for superviseur in self.superviseurs])
-            rtt_ss = np.array(
-                [superviseur.RTT_S for superviseur in self.superviseurs])
-            ratio_filtres = rtt_ss/rtt_ls
-
-            representation_etat = np.array(
-                [taux_retransmissions, ratio_rtts, ratio_filtres])
-
-        else:
-            representation_etat = np.array([taux_retransmissions, ratio_rtts])
-
-        return representation_etat
-        """
         vecteurs = []
         for n, superviseur in enumerate(self.superviseurs):
             if isinstance(superviseur, SuperviseurLocalFiltre):
-                try :
-                    vecteurs.append(np.array([[superviseur.taux_retransmission, superviseur.min_RTT/superviseur.avg_RTT, superviseur.RTT_S/superviseur.RTT_L]], dtype=np.float32))
+                try:
+                    vecteurs.append(np.array([[superviseur.taux_retransmission, superviseur.min_RTT /
+                                               superviseur.avg_RTT, superviseur.RTT_S/superviseur.RTT_L]], dtype=np.float32))
                 except NoRttError:
-                    vecteurs.append(self._last_state[:,n])
-        return np.concatenate(vecteurs, axis=0).T
+                    vecteurs.append(self._last_state[:, n].reshape((1, 3)))
+        etat = np.concatenate(vecteurs, axis=0).T
+        self._last_state = etat
+        return etat

     def application_action(self, actions):
         for n, alpha in enumerate(actions):
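Note (not part of the diff): each vector appended in the new state property has shape (1, 3) — retransmission rate, min_RTT/avg_RTT and RTT_S/RTT_L for one client — so after concatenation along axis 0 and the final transpose the state is a (3, n_clients) matrix, which is what the cached _last_state fallback reuses column by column. A minimal shape check with made-up numbers:

import numpy as np

# One (1, 3) row vector per client, as built in SuperviseurGlobal.state.
vecteurs = [np.array([[0.1, 0.8, 1.2]], dtype=np.float32) for _ in range(4)]
etat = np.concatenate(vecteurs, axis=0).T
print(etat.shape)                        # (3, 4): one row per measurement, one column per client
print(etat[:, 2].reshape((1, 3)).shape)  # (1, 3): how a cached column is reused for one client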
@@ -221,18 +196,29 @@ class SuperviseurGlobal():
     def reset(self):
         [superviseur.reset() for superviseur in self.superviseurs]

+    def reset_rto(self):
+        for superviseur in self.superviseurs:
+            superviseur.reset_rto()
+
+    @property
+    def failed(self):
+        return sum([superviseur._n_echec for superviseur in self.superviseurs])
+
     def qualite(self, n_request, beta_retransmission, beta_equite, beta_RTO):

         n_envoies = np.array([
             superviseur._n_envoie for superviseur in self.superviseurs])
-        n_tokens = np.array([superviseur._n_token for superviseur in self.superviseurs])
+        n_tokens = np.array(
+            [superviseur._n_token for superviseur in self.superviseurs])
         RTOs = np.array([superviseur.RTO for superviseur in self.superviseurs])

         qualite = 0
-        qualite -= beta_retransmission * sum(n_tokens)/sum(n_envoies)
+        qualite += beta_retransmission * (sum(n_tokens)/sum(n_envoies))
         qualite += beta_equite * \
             (sum(n_envoies/n_tokens))**2 / \
             (len(n_envoies) * sum((n_envoies/n_tokens)**2))
-        qualite -= beta_RTO * np.max(RTOs)
+        qualite += beta_RTO * (2-np.max(RTOs))
+
+        if np.isnan(qualite):
+            return 0
         return qualite
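Note (not part of the diff): the beta_equite term is Jain's fairness index of the per-client ratios n_envoies/n_tokens, i.e. (sum x)^2 / (n * sum x^2); it equals 1 when every client gets the same ratio and 1/n in the most unequal case. A standalone sketch with made-up ratios:

import numpy as np

def jain_index(x):
    # Jain's fairness index: 1.0 when all values are equal, 1/len(x) at worst.
    x = np.asarray(x, dtype=float)
    return x.sum()**2 / (len(x) * (x**2).sum())

print(jain_index([1.0, 1.0, 1.0, 1.0]))  # 1.0  (perfectly fair)
print(jain_index([4.0, 0.0, 0.0, 0.0]))  # 0.25 (one client takes everything)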
26  demo_env.py  Normal file
@@ -0,0 +1,26 @@
import socket
import time

from coapthon.client.helperclient import HelperClient
from coapthon.client.superviseur import (SuperviseurGlobal,
                                         SuperviseurLocalFiltre)
from coapthon.utils import parse_uri
from utils_learning import MaquetteCoapEnv, RequettePeriodique

host, port, path = parse_uri("coap://raspberrypi.local/basic")
try:
    tmp = socket.gethostbyname(host)
    host = tmp
except socket.gaierror:
    pass

clients = [HelperClient(server=(host, port)) for _ in range(5)]

environment = MaquetteCoapEnv(
    clients, SuperviseurLocalFiltre, SuperviseurGlobal, path)

requests = [RequettePeriodique(client, 2, path, name="Spamer {}".format(
    n)) for n, client in enumerate(clients)]
[request.start() for request in requests]
while True:
    print(environment.step(5*[0]))
@@ -1,14 +1,16 @@
-import matplotlib.pyplot as plt
 import socket
-import time
 import threading
+import time
+
+import matplotlib.pyplot as plt

 from coapthon.client.helperclient import HelperClient
-from coapthon.client.superviseur import SuperviseurLocal, SuperviseurLocalFiltre
+from coapthon.client.superviseur import (SuperviseurLocal,
+                                         SuperviseurLocalFiltre)
 from coapthon.utils import parse_uri

-N_rep = 50
-N_client = 100
+N_rep = 100
+N_client = 200

 host, port, path = parse_uri("coap://raspberrypi.local/basic")
 try:
@@ -24,21 +26,24 @@ for client in clients:
     client.protocol.superviseur = SuperviseurLocal(client)
     supers.append(client.protocol.superviseur)


 def experience(client, N_rep):
     for n_rep in range(N_rep):
         response = client.get(path)
     client.stop()


-threads = [threading.Thread(target=experience, args=[client, N_rep], name='Thread-experience-{}'.format(n)) for n, client in enumerate(clients)]
-
-for thread in threads :
+threads = [threading.Thread(target=experience, args=[
+                            client, N_rep], name='Thread-experience-{}'.format(n)) for n, client in enumerate(clients)]
+
+for thread in threads:
     thread.start()

-for thread in threads :
+for thread in threads:
     thread.join()

 fig, axs = plt.subplots(3, sharex=True)
-for n, ax in enumerate(axs) :
+for n, ax in enumerate(axs):
     ax.hist(supers[n].RTTs, 100, density=True)

 axs[-1].set_xlabel('RTT (s)')
@@ -48,4 +53,5 @@ fig.tight_layout()
 fig.savefig('demo.png')

 for n, super in enumerate(supers):
-    print("{:<5} | {:.5E} | {:.5E} | {:.5E} | {:0>5} | {:0>5}".format(n, super.min_RTT, super.avg_RTT, super.tau_retransmission, super._n_envoie, super._n_tokken))
+    print("{:<5} | {:.5E} | {:.5E} | {:.3%} | {:0>5} | {:0>5}".format(n, super.min_RTT,
+                                                                      super.avg_RTT, super.taux_retransmission, super._n_envoie, super._n_tokken))
@@ -1,9 +1,10 @@
 import matplotlib.pyplot as plt
 import socket
 import time
+from coapthon.client import superviseur_local

 from coapthon.client.helperclient import HelperClient
-from coapthon.client.superviseur import SuperviseurLocal, SuperviseurLocalFiltre
+from coapthon.client.superviseur_local import SuperviseurLocal, SuperviseurLocalFiltre
 from coapthon.utils import parse_uri

@@ -13,6 +14,7 @@ try:
     host = tmp
 except socket.gaierror:
     pass
+
 print('start client')
 client = HelperClient(server=(host, port))
 print('client started')
@@ -27,8 +29,8 @@ N_rep = 100
 for n_rep in range(N_rep):
     # print('rep{}'.format(n_rep))
     response = client.get(path)
-    rtt_l.append(super._RTT_L)
-    rtt_s.append(super._RTT_S)
+    rtt_l.append(super.RTT_L)
+    rtt_s.append(super.RTT_S)
     # time.sleep(1)
     # print("{} : \n{}".format(n_rep, response.pretty_print()))
 client.stop()
64  limite_requette.py  Normal file
@@ -0,0 +1,64 @@
# Saturation test: ramp up the number of concurrent clients, run N_REQUETTE GET
# requests per client, and record the global supervisor state after each batch.
import socket
import threading
import time

import matplotlib.pyplot as plt
import numpy as np

from coapthon.client.helperclient import HelperClient
from coapthon.client.superviseur import (SuperviseurGlobal, SuperviseurLocal,
                                         SuperviseurLocalFiltre)
from coapthon.utils import parse_uri

host, port, path = parse_uri("coap://raspberrypi.local/basic")
try:
    tmp = socket.gethostbyname(host)
    host = tmp
except socket.gaierror:
    pass


def experience(client, N_rep):
    for n_rep in range(N_rep):
        response = client.get(path)
    client.stop()


N_REQUETTE = 20

results = []

N_clients = np.linspace(1, 150, 100, dtype=np.int)
try:
    for n_client in N_clients:
        print("Test à {}".format(n_client))
        clients = [HelperClient(server=(host, port)) for _ in range(n_client)]
        super_global = SuperviseurGlobal(clients, SuperviseurLocalFiltre)
        threads = [threading.Thread(target=experience, args=[
            client, N_REQUETTE], name='T-{}-{}'.format(n_client, n)) for n, client in enumerate(clients)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        results.append(super_global.state)
        time.sleep(3)
except KeyboardInterrupt:
    [thread.join() for thread in threads]
    [client.close() for client in clients]

fig, axs = plt.subplots(3, 1, sharex=True)

# One subplot per measurement, plotted for the first client of each batch.
for idx in range(3):
    axs[idx].plot(N_clients[0:len(results)], [results[n][idx][0]
                                              for n, _ in enumerate(results)])

axs[0].set_ylabel("""Taux de\nretransmission""")
axs[1].set_ylabel("""$\\frac{min_{rtt}}{avg_{rtt}}$""")
axs[2].set_ylabel("""$\\frac{rtt_s}{rtt_l}$""")

axs[-1].set_xlabel("""nombre de requette simultanées""")

fig.tight_layout()
fig.savefig("""n_client_saturation.png""")
fig.savefig("""n_client_saturation.svg""")
107  utils_learning.py  Normal file
@@ -0,0 +1,107 @@
from __future__ import absolute_import, division, print_function

import abc
import socket
import threading
import time
from typing import Any, Callable, Iterable, Mapping, Optional

import numpy as np
import tensorflow as tf
from tf_agents.environments import (py_environment, tf_environment,
                                    tf_py_environment, utils, wrappers)
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

from coapthon.client.helperclient import HelperClient
from coapthon.client.superviseur import (SuperviseurGlobal,
                                         SuperviseurLocalFiltre)
from coapthon.utils import parse_uri


class RequettePeriodique(threading.Thread):
    """Thread that issues a GET request on `path` every `period` seconds."""

    def __init__(self, client: HelperClient, period: float, path: str, group: None = None,
                 target: Optional[Callable[..., Any]] = None, name: Optional[str] = None,
                 args: Iterable[Any] = (), kwargs: Optional[Mapping[str, Any]] = None, *,
                 daemon: Optional[bool] = None) -> None:
        super().__init__(group=group, target=target, name=name,
                         args=args, kwargs=kwargs, daemon=daemon)
        self._client = client
        self._period = period
        self._path = path

    def run(self):
        # Loop as long as the period is non-zero, keeping the request cadence
        # even when the GET itself takes a noticeable fraction of the period.
        while self.period:
            next_tick = time.monotonic() + self.period
            self._client.get(self._path)
            remaining = next_tick - time.monotonic()
            if remaining > 0:
                time.sleep(remaining)

    @property
    def period(self):
        return self._period

    @period.setter
    def period(self, value):
        if value >= 0:
            self._period = value
        else:
            raise ValueError


class MaquetteCoapEnv(py_environment.PyEnvironment):
    """PyEnvironment wrapping a pool of CoAP clients driven by periodic requests
    and observed through a global supervisor."""

    def __init__(self, clients: Iterable[HelperClient], superviseur_local_type: type,
                 superviseur_global_type: type, request_path: str, args_reward: Iterable[Any] = (),
                 control_period: float = 30, request_period: Optional[Iterable[float]] = None) -> None:

        self.clients = clients
        self.super_g = superviseur_global_type(clients, superviseur_local_type)

        self._action_spec = array_spec.BoundedArraySpec(
            shape=(len(clients),), dtype=np.float32, minimum=-10, maximum=10, name='action')
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(superviseur_global_type.nombre_mesure, len(clients)), dtype=np.float32, minimum=0, name='observation')
        self._episode_ended = False
        self._current_time_step = np.zeros(
            (3, len(self.clients)), dtype=np.float32)
        self.control_period = control_period

        self._args_reward = args_reward

        if request_period is None:
            request_period = [5 for client in clients]

        self.requests = [RequettePeriodique(client, request_period[n], request_path, name="Spamer {}".format(
            n)) for n, client in enumerate(clients)]
        [request.start() for request in self.requests]

    @property
    def request_period(self):
        return [request.period for request in self.requests]

    def action_spec(self) -> array_spec.BoundedArraySpec:
        return self._action_spec

    def observation_spec(self) -> array_spec.BoundedArraySpec:
        return self._observation_spec

    def _reset(self) -> ts.TimeStep:
        etat = np.zeros(
            (3, len(self.clients)), dtype=np.float32)
        self._current_time_step = etat
        self.super_g.reset_rto()
        return ts.transition(etat, reward=0)

    def _step(self, action: Iterable[float]):
        # Apply the per-client actions, let the network run for one control
        # period, then read the new state and compute the reward.
        self.super_g.application_action(action)
        self.super_g.reset()

        time.sleep(self.control_period)

        etat = self.super_g.state
        if self._args_reward == ():
            recompense = self.super_g.qualite(5*[1], 1000, 1, 1)
        else:
            recompense = self.super_g.qualite(5*[1], *self._args_reward)
        self._current_time_step = etat
        if self.super_g.failed:
            return ts.termination(etat, -10000)
        else:
            return ts.transition(etat, reward=recompense)
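Note (not part of the diff): a quick way to check that MaquetteCoapEnv respects the specs it declares is TF-Agents' environment validator. The sketch below assumes a reachable CoAP server at the same placeholder URI used above; the client count and the short control_period are arbitrary choices so the check finishes quickly.

import socket

from tf_agents.environments import utils

from coapthon.client.helperclient import HelperClient
from coapthon.client.superviseur import (SuperviseurGlobal,
                                         SuperviseurLocalFiltre)
from coapthon.utils import parse_uri
from utils_learning import MaquetteCoapEnv

host, port, path = parse_uri("coap://raspberrypi.local/basic")
try:
    host = socket.gethostbyname(host)
except socket.gaierror:
    pass

clients = [HelperClient(server=(host, port)) for _ in range(2)]
env = MaquetteCoapEnv(clients, SuperviseurLocalFiltre, SuperviseurGlobal, path,
                      control_period=5)

# Steps the environment with random actions sampled from action_spec() and
# checks the returned time steps against the declared observation spec.
utils.validate_py_environment(env, episodes=1)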