You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
4.0 KiB
Python
108 lines
4.0 KiB
Python
from __future__ import absolute_import, division, print_function
|
|
|
|
import abc
|
|
import socket
|
|
import threading
|
|
import time
|
|
from typing import Any, Callable, Iterable, Mapping, Optional
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from tf_agents.environments import (py_environment, tf_environment,
|
|
tf_py_environment, utils, wrappers)
|
|
from tf_agents.specs import array_spec
|
|
from tf_agents.trajectories import time_step as ts
|
|
|
|
from coapthon.client.helperclient import HelperClient
|
|
from coapthon.client.superviseur import (SuperviseurGlobal,
|
|
SuperviseurLocalFiltre)
|
|
from coapthon.utils import parse_uri
|
|
|
|
|
|
class RequettePeriodique(threading.Thread):
    """Thread that sends a GET request on a fixed path at a regular period.

    Each iteration of ``run`` calls ``client.get(path)`` and then sleeps for
    whatever is left of the current period. The loop condition is the
    truthiness of ``period``, so setting ``period`` to ``0`` makes the thread
    exit once the iteration in progress (including its sleep) is over.

    Args:
        client: Object whose ``get(path)`` method is called once per period.
        period: Seconds between the start of two consecutive requests.
            Must be ``>= 0``; with ``0`` the ``run`` loop never executes.
        path: Resource path passed to ``client.get``.
        group, target, name, args, kwargs, daemon: Forwarded unchanged to
            ``threading.Thread.__init__``. ``run`` is overridden here and
            does not invoke ``target``.

    Raises:
        ValueError: If ``period`` is negative.
    """

    def __init__(self, client: HelperClient, period: float, path: str, group: None = None, target: Optional[Callable[..., Any]] = None, name: Optional[str] = None, args: Iterable[Any] = (), kwargs: Optional[Mapping[str, Any]] = None, *, daemon: Optional[bool] = None) -> None:
        super().__init__(group=group, target=target, name=name,
                         args=args, kwargs=kwargs, daemon=daemon)
        self._client = client
        self._path = path
        # Go through the property so the constructor enforces the same
        # constraint as later assignments: a negative period would otherwise
        # make run() spin without ever sleeping.
        self.period = period

    def run(self) -> None:
        """Issue one GET per period until ``period`` becomes falsy."""
        while self.period:
            # Fix the deadline before the request so that the time spent in
            # get() is deducted from the sleep instead of added to the period.
            deadline = time.monotonic() + self.period
            self._client.get(self._path)
            remaining = deadline - time.monotonic()
            if remaining > 0:
                time.sleep(remaining)

    @property
    def period(self) -> float:
        """Seconds between the start of two consecutive requests."""
        return self._period

    @period.setter
    def period(self, value: float) -> None:
        # `not value >= 0` (rather than `value < 0`) also rejects NaN.
        if not value >= 0:
            raise ValueError(
                "period must be >= 0, got {!r}".format(value))
        self._period = value
|
|
|
|
|
|
class MaquetteCoapEnv(py_environment.PyEnvironment):
    """TF-Agents Python environment built on a set of CoAP clients.

    One ``RequettePeriodique`` thread per client is created and started by
    the constructor. A step hands the action to the global supervisor,
    waits ``control_period`` seconds, then reads the supervisor's state
    (observation) and quality (reward).

    Args:
        clients: The CoAP clients; materialised into a list because their
            count is needed for the specs.
        superviseur_local_type: Passed as second argument to
            ``superviseur_global_type``.
        superviseur_global_type: Callable/class invoked as
            ``superviseur_global_type(clients, superviseur_local_type)``;
            its ``nombre_mesure`` attribute gives the first dimension of the
            observation spec.
        request_path: Path requested periodically by every client.
        args_reward: Extra positional arguments appended to
            ``super_g.qualite(5 * [1], ...)``. When empty, ``1000, 1, 1``
            is used.
        control_period: Seconds slept inside every step.
        request_period: Request period of each client, in client order.
            Defaults to 5 for every client.
    """

    def __init__(self, clients: Iterable[HelperClient], superviseur_local_type: type, superviseur_global_type: type, request_path: str, args_reward: Iterable[Any] = (),
                 control_period: float = 30, request_period: Optional[Iterable[float]] = None) -> None:
        # Let the base class set up its own bookkeeping (notably
        # `_current_time_step`, which this class must not overwrite).
        super().__init__()

        # `clients` is annotated Iterable but len() is required below.
        self.clients = list(clients)
        nb_clients = len(self.clients)
        self.super_g = superviseur_global_type(
            self.clients, superviseur_local_type)

        self._action_spec = array_spec.BoundedArraySpec(
            shape=(nb_clients,), dtype=np.float32, minimum=-10, maximum=10, name='action')
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(superviseur_global_type.nombre_mesure, nb_clients), dtype=np.float32, minimum=0, name='observation')
        self._episode_ended = False
        # Most recent observation; kept under its own name so the base
        # class' `_current_time_step` keeps holding TimeStep objects only.
        self._last_observation = np.zeros((3, nb_clients), dtype=np.float32)
        self.control_period = control_period

        # Frozen into a tuple so that an empty list is treated like the
        # default and a one-shot iterable is not exhausted by the first step.
        self._args_reward = tuple(args_reward)

        if request_period is None:
            periods = [5] * nb_clients
        else:
            periods = list(request_period)

        self.requests = [
            RequettePeriodique(client, periods[n], request_path,
                               name="Spamer {}".format(n))
            for n, client in enumerate(self.clients)
        ]
        for requete in self.requests:
            requete.start()

    @property
    def request_period(self):
        """Current request period of every client thread, in client order."""
        return [requete.period for requete in self.requests]

    def action_spec(self) -> array_spec.BoundedArraySpec:
        """Return the spec of the actions accepted by ``step``."""
        return self._action_spec

    def observation_spec(self) -> array_spec.BoundedArraySpec:
        """Return the spec of the observations produced by the environment."""
        return self._observation_spec

    def _reset(self) -> ts.TimeStep:
        """Start a new episode and return its first time step.

        The observation is all zeros and ``super_g.reset_rto()`` is called.
        """
        # NOTE(review): the first dimension is hard-coded to 3 whereas the
        # observation spec uses `nombre_mesure` -- confirm they agree.
        etat = np.zeros((3, len(self.clients)), dtype=np.float32)
        self._last_observation = etat
        self._episode_ended = False
        self.super_g.reset_rto()
        # A reset must yield a FIRST step; ts.transition would yield MID.
        return ts.restart(etat)

    def _step(self, action: Iterable[float]) -> ts.TimeStep:
        """Apply ``action``, wait ``control_period`` seconds and observe.

        Returns a termination step with reward -10000 when ``super_g.failed``
        is true, otherwise a transition step carrying the supervisor's
        quality as reward. If the previous step ended the episode, the
        action is ignored and a new episode is started instead.
        """
        if self._episode_ended:
            return self.reset()

        self.super_g.application_action(action)
        self.super_g.reset()

        # Let the system run under the new action before measuring.
        time.sleep(self.control_period)

        etat = self.super_g.state
        extra_args = self._args_reward or (1000, 1, 1)
        recompense = self.super_g.qualite(5 * [1], *extra_args)
        self._last_observation = etat

        if self.super_g.failed:
            self._episode_ended = True
            return ts.termination(etat, -10000)
        return ts.transition(etat, reward=recompense)
|