M2_SETI/IA/seti_master-master/code/visualize_outputs.py
2023-01-29 16:56:40 +01:00

81 lines
2.3 KiB
Python

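"""A set of functions to visualise the output space of a neural
network: inputs are sampled uniformly in [0, 1) and each sample is
coloured red where the network output is positive, blue otherwise.
"""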
import random
import sys
import torch
import matplotlib.pyplot as plt
from onnx2pytorch import convert
def sample2d(model, n_sample=1000, frozen_dim=0, frozen_val=0.5):
    """Sample the model on a 2-D slice of its 3-D input space.

    One input coordinate (``frozen_dim``) is held at ``frozen_val`` while the
    other two are drawn uniformly from [0, 1) with ``random.random()``.

    Args:
        model: callable applied to a 1-D ``torch.tensor`` of 3 floats; its
            result is compared with ``> 0`` (so it should hold a single
            value).
        n_sample: number of random points to draw.
        frozen_dim: index (0, 1 or 2) of the input coordinate held fixed.
        frozen_val: value assigned to the frozen coordinate.

    Returns:
        ``(dim_1, dim_2, colours)``: the two free coordinates of every sample
        and its colour (``'red'`` when the model output is > 0, else
        ``'blue'``). The free values fill the input in cyclic order after the
        frozen slot, so for ``frozen_dim == 1`` ``dim_1`` is input index 2
        and ``dim_2`` is input index 0.

    Raises:
        ValueError: if ``frozen_dim`` is not 0, 1 or 2 (checked even when
            ``n_sample`` is 0).
    """
    if frozen_dim not in (0, 1, 2):
        raise ValueError(
            "Error, frozen_dim must be 0, 1 or 2, got {}".format(frozen_dim))
    dim_1, dim_2, colours = [], [], []
    # Pure inference: don't build (and keep alive) an autograd graph per call.
    with torch.no_grad():
        for _ in range(n_sample):
            x1 = random.random()
            x2 = random.random()
            if frozen_dim == 0:
                x = torch.tensor([frozen_val, x1, x2])
            elif frozen_dim == 1:
                x = torch.tensor([x2, frozen_val, x1])
            else:
                x = torch.tensor([x1, x2, frozen_val])
            y = model(x)
            dim_1.append(x1)
            dim_2.append(x2)
            # Decide the colour now instead of retaining the output tensor.
            colours.append('red' if y > 0 else 'blue')
    return dim_1, dim_2, colours
def sample3d(model, n_sample=1000):
    """Sample the model uniformly over the unit cube [0, 1)^3.

    Args:
        model: callable applied to a 1-D ``torch.tensor`` of 3 floats; its
            result is compared with ``> 0`` (so it should hold a single
            value).
        n_sample: number of random points to draw.

    Returns:
        ``(dim_1, dim_2, dim_3, colours)``: the three coordinates of every
        sample and its colour (``'red'`` when the model output is > 0, else
        ``'blue'``).
    """
    dim_1, dim_2, dim_3, colours = [], [], [], []
    # Pure inference: don't build (and keep alive) an autograd graph per call.
    with torch.no_grad():
        for _ in range(n_sample):
            x1 = random.random()
            x2 = random.random()
            x3 = random.random()
            y = model(torch.tensor([x1, x2, x3]))
            dim_1.append(x1)
            dim_2.append(x2)
            dim_3.append(x3)
            # Decide the colour now instead of retaining the output tensor.
            colours.append('red' if y > 0 else 'blue')
    return dim_1, dim_2, dim_3, colours
def plot2d(x1s, x2s, colours):
    """Scatter-plot 2-D points, one colour per point, and show the figure.

    Args:
        x1s: x-coordinates of the points.
        x2s: y-coordinates of the points.
        colours: one matplotlib colour per point (e.g. as produced by
            ``sample2d``).
    """
    # zip() pairs each point with its colour and stops at the shortest input.
    points = list(zip(x1s, x2s, colours))
    if points:
        xs, ys, cs = zip(*points)
        # A single scatter call draws every point in one collection instead
        # of creating one artist per point.
        plt.scatter(xs, ys, c=list(cs))
    # plt.title("ALPHA=0.5, BETA=0.25, DELTA1=0.3, DELTA2=0.7, EPSILON=0.09")
    plt.show()
def plot3d(x1s, x2s, x3s, colours):
    """Scatter-plot 3-D points, one colour per point, and show the figure.

    Axes are labelled 'distance', 'speed' and 'angle' in input order.

    Args:
        x1s, x2s, x3s: coordinates of the points along each axis.
        colours: one matplotlib colour per point (e.g. as produced by
            ``sample3d``).
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # zip() pairs each point with its colour and stops at the shortest input.
    points = list(zip(x1s, x2s, x3s, colours))
    if points:
        xs, ys, zs, cs = zip(*points)
        # A single scatter call draws every point in one collection instead
        # of creating one artist per point.
        ax.scatter(xs, ys, zs, c=list(cs))
    ax.set_xlabel('distance')
    ax.set_ylabel('speed')
    ax.set_zlabel('angle')
    plt.show()
if __name__ == '__main__':
    # Usage: visualize_outputs.py N_SAMPLE FROZEN_DIM FROZEN_VAL
    if len(sys.argv) < 4:
        sys.exit("usage: {} N_SAMPLE FROZEN_DIM FROZEN_VAL".format(sys.argv[0]))
    n_sample = int(sys.argv[1])
    frozen_dim = int(sys.argv[2])
    frozen_val = float(sys.argv[3])
    # NOTE(review): relies on `convert` (imported from onnx2pytorch) being
    # callable with a file path; upstream onnx2pytorch exposes
    # `ConvertModel(onnx.load(path))` instead — confirm which API is in use.
    # The path is resolved relative to the current working directory.
    model = convert("network.onnx")
    dim_1, dim_2, colours = sample2d(model, n_sample, frozen_dim, frozen_val)
    plot2d(dim_1, dim_2, colours)
    dim_1, dim_2, dim_3, colours = sample3d(model, n_sample)
    plot3d(dim_1, dim_2, dim_3, colours)