M2_SETI/IA/seti_master-master/code/tutorial.ipynb


2023-01-29 16:56:40 +01:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Formal verification of deep neural networks: a tutorial\n",
"\n",
"The aim of this tutorial is to give a glimpse of the practical side of Formal Verification for Deep Neural Networks.\n",
"This tutorial is divided into four parts:\n",
"1. Verification by hand\n",
"2. Small problem verification\n",
"3. Real use case application\n",
"4. Image classification\n",
"\n",
"The tutorial material was written by Augustin Lemesle (CEA List) with material provided by Serge Durand (CEA) based on a previous tutorial created by Julien Girard-Satabin (CEA LIST/INRIA), Zakaria Chihani (CEA LIST) and Guillaume Charpiat (INRIA)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 1: Verification by hand\n",
"\n",
"This first part gives a rough overview of the challenges posed by the verification of a neural network. In the first part of the lesson you should have seen a technique called Abstract Interpretation, which can leverage intervals to over-approximate the output of a network, along with an example computed by hand. In this part, we will develop a small class that computes the output automatically with intervals.\n",
"\n",
"### Step 1: Encode the network\n",
"\n",
"![image](imgs/network.png)\n",
"\n",
"Using the network above, write a function `network(x1, x2)` that reproduces its behavior. It must pass the tests in `test_network`. The ReLU activation function is already implemented below."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def relu(x):\n",
"    return max(0, x)\n",
"\n",
"def network(x1, x2):\n",
"    # first layer: affine combinations of the inputs\n",
"    x3 = 2*x1 + x2 + 1\n",
"    x4 = -x1 + x2\n",
"    # ReLU activations\n",
"    x3p = relu(x3)\n",
"    x4p = relu(x4)\n",
"    # second layer, then the output\n",
"    x5 = 2*x4p - 0.5*x3p - 1\n",
"    x6 = x4p - x3p + 2\n",
"    y = x5 - x6\n",
"    return y"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def test_network():\n",
"    assert network(0, 0) == -2.5\n",
"    assert network(0, -1) == -3\n",
"    assert network(0, 1) == -1\n",
"    assert network(-1, 1) == -1\n",
"    assert network(-1, 0) == -2\n",
"    assert network(-1, -1) == -3\n",
"    assert network(1, 0) == -1.5\n",
"    assert network(1, -1) == -2\n",
"    assert network(1, 1) == -1\n",
"\n",
"test_network()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2: Create an Interval\n",
"\n",
"Following the rules of interval arithmetic, write a class representing an interval by overriding Python operators. A skeleton is available below.\n",
"\n",
"Interval rules:\n",
"- $[l, u] + \\lambda = [l + \\lambda, u + \\lambda]$\n",
"- $[l, u] + [l', u'] = [l + l', u + u']$\n",
"- $-[l, u] = [-u, -l]$\n",
"- $[l, u] - [l', u'] = [l - u', u - l']$\n",
"- $[l, u] \\times \\lambda =$\n",
"    - if $\\lambda \\geq 0$: $[\\lambda l, \\lambda u]$\n",
"    - if $\\lambda < 0$: $[\\lambda u, \\lambda l]$\n",
"- $[l, u] \\times [l', u'] = [\\min(ll', lu', ul', uu'), \\max(ll', lu', ul', uu')]$\n",
"\n",
"We will also need to update the ReLU so that it works on intervals.\n",
"\n",
"Some tests are available for you to check if the class implementation is correct."
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"class Interval:\n",
"    def __init__(self, lower, upper):\n",
"        self.lower = lower\n",
"        self.upper = upper\n",
"\n",
"    def __add__(self, other):\n",
"        if isinstance(other, Interval):\n",
"            return Interval(self.lower + other.lower, self.upper + other.upper)\n",
"        else:\n",
"            return Interval(self.lower + other, self.upper + other)\n",
"\n",
"    def __radd__(self, other):\n",
"        # lambda + [l, u], e.g. 1 + Interval(0, 1)\n",
"        return self.__add__(other)\n",
"\n",
"    def __sub__(self, other):\n",
"        if isinstance(other, Interval):\n",
"            # [l, u] - [l', u'] = [l, u] + [-u', -l']\n",
"            return self.__add__(other.__neg__())\n",
"        else:\n",
"            return Interval(self.lower - other, self.upper - other)\n",
"\n",
"    def __rsub__(self, other):\n",
"        # lambda - [l, u] = -[l, u] + lambda\n",
"        return self.__neg__().__add__(other)\n",
"\n",
"    def __neg__(self):\n",
"        return Interval(-self.upper, -self.lower)\n",
"\n",
"    def __mul__(self, other):\n",
"        if isinstance(other, Interval):\n",
"            # the bounds of a product are reached at combinations of the endpoints\n",
"            products = [self.lower * other.lower, self.lower * other.upper,\n",
"                        self.upper * other.lower, self.upper * other.upper]\n",
"            return Interval(min(products), max(products))\n",
"        elif other < 0:\n",
"            return Interval(self.upper * other, self.lower * other)\n",
"        else:\n",
"            return Interval(self.lower * other, self.upper * other)\n",
"\n",
"    def __rmul__(self, other):\n",
"        return self.__mul__(other)\n",
"\n",
"    def __str__(self):\n",
"        return f\"[{self.lower}, {self.upper}]\"\n",
"\n",
"    def __repr__(self):\n",
"        return self.__str__()\n",
"\n",
"    def __eq__(self, other):\n",
"        return self.lower == other.lower and self.upper == other.upper"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"def relu(x):\n",
"    if isinstance(x, Interval):\n",
"        lower = max(0, x.lower)\n",
"        upper = max(0, x.upper)\n",
"        return Interval(lower, upper)\n",
"    else:\n",
"        return max(0, x)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"def test_interval():\n",
"    assert Interval(0, 1) == Interval(0, 1)\n",
"    assert -Interval(0, 1) == Interval(-1, 0)\n",
"    assert Interval(0, 1) + Interval(1, 2) == Interval(1, 3)\n",
"    assert Interval(0, 1) - Interval(1, 2) == Interval(-2, 0)\n",
"    assert Interval(-1, 2) * 3 == Interval(-3, 6)\n",
"    assert Interval(-1, 2) * -3 == Interval(-6, 3)\n",
"    assert relu(Interval(-2, 3)) == Interval(0, 3)\n",
"\n",
"test_interval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 3: Run the network with intervals\n",
"\n",
"At this point you should be able to run the network on intervals using the `Interval` class and observe the range of outputs it reaches."
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-7.0, 5.0]\n"
]
}
],
"source": [
"print(network(Interval(-1, 1), Interval(-1, 1)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Bonus step: To go further\n",
"\n",
"- Reproduce the first neural network from the slides to confirm the results\n",
"- Implement a class for an AffineForm to compute more precise outputs"
]
},
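{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the second bonus item, here is one possible sketch of affine arithmetic (our own illustration, not part of the original material). An affine form represents a value as $x = c_0 + \\sum_i c_i \\epsilon_i$ with each noise symbol $\\epsilon_i \\in [-1, 1]$. Because neurons that depend on the same inputs share noise symbols, expressions such as `x5 - x6` can partially cancel out instead of widening as they do with intervals. The names `AffineForm`, `relu_affine` and `network_affine` are hypothetical. The ReLU uses the usual linear relaxation $\\lambda x \\leq \\mathrm{relu}(x) \\leq \\lambda (x - l)$ with $\\lambda = u / (u - l)$ when $l < 0 < u$. By our hand computation, it should give bounds of roughly $[-3.67, 0]$ on the Part 1 network, tighter than the $[-7, 5]$ obtained with intervals."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"class AffineForm:\n",
"    # x = center + sum_i coeffs[i] * eps_i, with each noise symbol eps_i in [-1, 1]\n",
"    _fresh = itertools.count()  # generator of fresh noise-symbol ids\n",
"\n",
"    def __init__(self, center, coeffs=None):\n",
"        self.center = center\n",
"        self.coeffs = dict(coeffs or {})\n",
"\n",
"    @classmethod\n",
"    def from_interval(cls, lower, upper):\n",
"        # [l, u] = midpoint + radius * eps, with a new noise symbol eps\n",
"        return cls((lower + upper) / 2, {next(cls._fresh): (upper - lower) / 2})\n",
"\n",
"    def bounds(self):\n",
"        radius = sum(abs(c) for c in self.coeffs.values())\n",
"        return self.center - radius, self.center + radius\n",
"\n",
"    def __add__(self, other):\n",
"        if isinstance(other, AffineForm):\n",
"            coeffs = dict(self.coeffs)\n",
"            for k, c in other.coeffs.items():\n",
"                coeffs[k] = coeffs.get(k, 0) + c\n",
"            return AffineForm(self.center + other.center, coeffs)\n",
"        return AffineForm(self.center + other, self.coeffs)\n",
"\n",
"    __radd__ = __add__\n",
"\n",
"    def __neg__(self):\n",
"        return self * -1\n",
"\n",
"    def __sub__(self, other):\n",
"        return self + (-other)\n",
"\n",
"    def __rsub__(self, other):\n",
"        return (-self) + other\n",
"\n",
"    def __mul__(self, scalar):\n",
"        # multiplication by a constant is exact; products of two affine forms are not handled here\n",
"        return AffineForm(self.center * scalar, {k: c * scalar for k, c in self.coeffs.items()})\n",
"\n",
"    __rmul__ = __mul__\n",
"\n",
"def relu_affine(x):\n",
"    l, u = x.bounds()\n",
"    if l >= 0:\n",
"        return x  # always active\n",
"    if u <= 0:\n",
"        return AffineForm(0)  # always inactive\n",
"    lam = u / (u - l)\n",
"    mu = -lam * l / 2\n",
"    # lam*x <= relu(x) <= lam*x - lam*l, encoded as lam*x + mu + mu*eps_new\n",
"    y = lam * x + mu\n",
"    y.coeffs[next(AffineForm._fresh)] = mu\n",
"    return y\n",
"\n",
"def network_affine(x1, x2):\n",
"    # same computations as network(), with relu_affine as the activation\n",
"    x3p = relu_affine(2*x1 + x2 + 1)\n",
"    x4p = relu_affine(-x1 + x2)\n",
"    x5 = 2*x4p - 0.5*x3p - 1\n",
"    x6 = x4p - x3p + 2\n",
"    return x5 - x6\n",
"\n",
"y = network_affine(AffineForm.from_interval(-1, 1), AffineForm.from_interval(-1, 1))\n",
"print(y.bounds())"
]
},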
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*****\n",
"\n",
"## Part 2: Small problem verification\n",
"\n",
"We provide a toy problem representative of current challenges in neural network verification, together with a deep neural network trained to solve it.\n",
"\n",
"The goal of this section is to formally verify that the neural network is _safe_, using the bundled tools."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Problem definition\n",
"\n",
"This toy problem is inspired by the Airborne Collision Avoidance System for Unmanned vehicles (ACAS-Xu) specification and threat model. \n",
"\n",
"![problem formulation](imgs/problem_small.png)\n",
"\n",
"Let A be a Guardian, and B a Threat.\n",
"The goal for the Guardian is to send an ALARM when the Threat arrives too close.\n",
"\n",
"The Guardian has access to the following data:\n",
"* The distance from B to A, $d = ||\\vec{d}||$\n",
"* The speed of B, $v = ||\\vec{v}||$\n",
"* The angle $\\theta$ between $\\vec{d}$ and $\\vec{v}$\n",
"\n",
"All values are normalized to $\\left[0,1\\right]$ and the angle is not oriented.\n",
"\n",
"We define three main “zones”:\n",
"1. a **“safe”** zone: when $||\\vec{d}|| > \\delta_2$, B is in the safe zone; it is not considered a threat and no ALARM is issued.\n",
"2. a **“suspicious”** zone: when $\\delta_1 \\leq ||\\vec{d}|| \\leq \\delta_2$, B is in the suspicious zone; if $||\\vec{v}|| > \\alpha$ and $\\theta < \\beta$,\n",
"   an ALARM should be issued. Otherwise, no ALARM is issued.\n",
"3. a **“danger”** zone: when $||\\vec{d}|| < \\delta_1$, B is in the danger zone and an ALARM is issued no matter what.\n",
"\n",
" \n",
"### Solving this problem with a neural network\n",
"\n",
"A neural network was pre-trained to solve this task (all files used to this end are available).\n",
"It has 5 fully connected layers: the first layer takes 3 inputs and the last layer has 1 output. The four hidden layers have sizes 10, 10, 5 and 2, and ReLUs are used as activation functions.\n",
"\n",
"The network was trained to output a positive value if there is an alarm, and a negative value if there is no alarm. For a detailed summary of the hyperparameters, check the defaults in `train.py`. It achieved 99.9% accuracy on the test set, with a total training set of 100,000 samples.\n",
"\n",
"The specification used to train the network is based on:\n",
"- $\\alpha = 0.5$\n",
"- $\\beta = 0.25$\n",
"- $\\delta_1 = 0.3$\n",
"- $\\delta_2 = 0.7$\n",
"\n",
"Our aim is to prove that the network respects these values.\n"
]
},
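{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the specification above concrete, here is a minimal sketch of it as a plain Python function (our own illustration; the name `expected_alarm` is hypothetical and not part of the repository). It treats $\\delta_1 \\leq d \\leq \\delta_2$ as the suspicious zone. It can serve as a reference when inspecting the network's answers: ideally, the network outputs a positive value exactly when this function returns `True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ALPHA, BETA, DELTA_1, DELTA_2 = 0.5, 0.25, 0.3, 0.7\n",
"\n",
"def expected_alarm(distance, speed, angle):\n",
"    # returns True if the specification requires an ALARM\n",
"    if distance < DELTA_1:\n",
"        return True  # danger zone: alarm no matter what\n",
"    if distance > DELTA_2:\n",
"        return False  # safe zone: never an alarm\n",
"    # suspicious zone: alarm only for a fast threat with a small angle\n",
"    return speed > ALPHA and angle < BETA\n",
"\n",
"assert expected_alarm(0.1, 0.0, 1.0)  # danger zone\n",
"assert not expected_alarm(0.9, 1.0, 0.0)  # safe zone\n",
"assert expected_alarm(0.5, 0.8, 0.1)  # suspicious zone, fast and small angle\n",
"assert not expected_alarm(0.5, 0.8, 0.9)  # suspicious zone, large angle"
]
},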
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the safety property\n",
"\n",
"The trained network is in the repository, under the filename `network.onnx`. Your goal is to learn how to write a safety property and launch different tools on the network.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 1: Visualization \n",
"\n",
"You can first visualize the network's answers over the input space by sampling inputs,\n",
"using the functions below (**careful, it may take a while if you use a large number of samples!**).\n",
"\n",
"`sample2d` is faster but only samples a 2D slice, while `sample3d` gives a full representation of the input space.\n",
"\n",
"Blue denotes no alarm, red denotes an alarm.\n"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"from visualize_outputs import sample2d, sample3d, plot2d, plot3d\n",
"from onnx2pytorch import convert\n",
"\n",
"%matplotlib widget\n",
"\n",
"n_sample = 1000 # number of points to sample\n",
"frozen_dim = 0 # which dimension will have a constant value for the 2d plot, (0: distance, 1: speed, 2: angle)\n",
"frozen_val = 0.9 # constant value to give to the frozen dimension\n",
"model = convert(\"network.onnx\")\n",
"dim_1, dim_2, colours = sample2d(model, n_sample, frozen_dim, frozen_val)\n",
"plot2d(dim_1, dim_2, colours)\n",
"dim_1, dim_2, dim_3, colours = sample3d(model, n_sample)\n",
"plot3d(dim_1, dim_2, dim_3, colours)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These plots already show some trends in where alarms are raised. As always, the difficult part will be near the decision boundary."
]
},
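{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before calling any verifier, a quick (and incomplete) sanity check is to sample many inputs satisfying the precondition of the property studied in the next step (distance above 0.95) and look for violations. The sketch below is our own illustration. It assumes that the model returned by `convert` accepts a batch of inputs of shape `(N, 3)` (distance, speed, angle) and returns one score per row. Finding no counterexample this way does **not** prove the property; only the formal tools used below can do that."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from onnx2pytorch import convert\n",
"\n",
"model = convert('network.onnx')\n",
"n_sample = 10000\n",
"\n",
"# uniform samples in [0, 1)^3, then distance squeezed into [0.95, 1)\n",
"inputs = torch.rand(n_sample, 3)\n",
"inputs[:, 0] = 0.95 + 0.05 * inputs[:, 0]\n",
"\n",
"with torch.no_grad():\n",
"    scores = model(inputs).reshape(-1)\n",
"\n",
"# a positive score means an alarm, which the property forbids here\n",
"counterexamples = inputs[scores >= 0]\n",
"print(f'{len(counterexamples)} counterexample(s) found among {n_sample} samples')"
]
},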
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2: Write a safety property\n",
"\n",
"Let us say we would like to check the following property:\n",
"\n",
"\"No input with a distance above 0.95 outputs an alarm.\"\n",
"\n",
"If our neural network correctly follows our specification, it should respect this property (since the distance is above $\\delta_2 = 0.7$, B is in the safe zone and no alarm should be issued).\n",
"\n",
"Properties are formulated for our various tools as follows:\n",
"\n",
"1. Write constraints on inputs as defined in the specification\n",
"1. State the constraint on the outputs you want to check\n",
"1. For tools using SMT, write down the negation of this constraint\n",
"   (remember from the course that VALID(F) is equivalent to ¬SAT(¬F)).\n",
"\n",
"For this tutorial, we wrote a simple API that makes it easier to write the properties to check. It is detailed in `formula_lang.py` in the repository.\n",
"This API lets you define linear constraints on symbolic variables and real values.\n",
"* To define a new variable, use the constructor `Var(x)`, where `x` must be a string\n",
"* To define a new real number, use `Real(r)` where `r` is a python number\n",
"* A constraint is a linear inequality between two variables or reals. To create a new constraint, use\n",
"  `constr = Constr(var1, bop, var2)` where `bop` is either `'>='` or `'<'`.\n",
"* You can create multiple constraints\n",
"* Finally, once you are satisfied, you can add your constraints to a formula. A formula is a conjunction of constraints\n",
"* `f = Formula()` creates an empty formula, and `f.add(constr)` adds the constraint `constr` to the formula.\n",
"  `f.add(c1)` followed by `f.add(c2)` is equivalent to adding the conjunction of `c1` and `c2`.\n",
"\n",
"\n",
"Here is how to use it:\n",
"#### Creating variables\n",
"1. Create a new variable with `var = Var(str)`; `str` should be \n",
" either `'x0'`, `'x1'`, `'x2'` or `'y0'`, respectively the first, \n",
" second and third input and only output. For convenience, they \n",
" are already defined when executing the cell below as\n",
" `distance`, `speed`, `angle` and `output`\n",
" \n",
"2. Create a new real value with `real = Real(r)` where `r`\n",
" can be an integer or a float (all variables will be converted\n",
" as real values). For instance, `real = Real(0.95)`\n",
" \n",
" \n",
"#### Creating constraints and adding them to a formula\n",
"1. Create a new constraint between a variable `var` and a\n",
" real value `real` with `constr = Constr(var, bop, real)` where\n",
" `bop` is either `'>='` or `'<'`. For instance, `constr = Constr(distance,'>=',real)`\n",
"2. Create a new empty formula with `f = Formula()`\n",
"3. Add a constraint `constr` to a formula `f` with `f.add(constr)`. \n",
" \n",
"#### Printing and saving to disk\n",
"1. Print a formula `f` with `print(f)`\n",
"2. Write down a formula `f` to SMTLIB2 format at destination `dest`\n",
" with `f.write_smtlib(dest)`; similarly for Marabou format,\n",
" use `f.write_marabou(dest)` or PyRAT format `f.write_pyrat(dest)`.\n",
"   For SMTLIB and Marabou, the output constraint is negated automatically.\n",
"\n",
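"A simple example is given below for the property \"No input with a distance above 0.95 outputs an alarm\". Note that the example also restricts `speed` and `angle` to $[0.95, 1)$, so it only checks part of the input region covered by this property; constraining them to $[0, 1)$ instead would check the whole property."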
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x0 >= 0.95; x0 < 1; x1 >= 0.95; x1 < 1; x2 >= 0.95; x2 < 1; y0 < 0; \n",
"Wrote SMT formula in file formula.smt2\n",
"Wrote marabou formula in file formula.marabou\n",
"Wrote pyrat formula in file formula.txt\n"
]
}
],
"source": [
"from formula_lang import *\n",
"\n",
"distance = Var('x0')\n",
"speed = Var('x1')\n",
"angle = Var('x2')\n",
"\n",
"output = Var('y0')\n",
"\n",
"one = Real(1)\n",
"real = Real(0.95)\n",
"zero = Real(0)\n",
"\n",
"constrs = []\n",
"constrs.append(Constr(distance, '>=', real))\n",
"constrs.append(Constr(distance, '<', one))\n",
"constrs.append(Constr(speed, '>=', real))\n",
"constrs.append(Constr(speed, '<', one))\n",
"constrs.append(Constr(angle, '>=', real))\n",
"constrs.append(Constr(angle, '<', one))\n",
"constrs.append(Constr(output, '<', zero))\n",
"\n",
"formula = Formula()\n",
"for c in constrs:\n",
"    formula.add(c)\n",
"\n",
"print(formula)\n",
"\n",
"formula.write_smtlib()\n",
"formula.write_marabou()\n",
"formula.write_pyrat()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch solvers and recover results\n",
"\n",
"As mentioned, we will use three tools in this tutorial:\n",
"\n",
"1. Z3, a theorem prover from Microsoft Research [https://github.com/Z3Prover/z3](https://github.com/Z3Prover/z3), used as a state-of-the-art SMT solver; however it does not have any particular heuristics to work on neural networks \n",
"2. PyRAT, a tool developed internally at the lab, that leverages abstract interpretation to verify reachability properties on neural networks. Its source code is not currently public; send us an email if you want access\n",
"3. Marabou, a solver tailored for neural network verification: it uses a specialized Simplex algorithm and merges relevant neurons. See [the paper](https://arxiv.org/abs/1910.14574) for more details\n",
"\n",
"You will notice that PyRAT performs a \"reachability analysis\" (given the input range, what is the possible output range?). Marabou does not handle disjunctions of clauses ($a< x1 <b \\vee c < x1 <d$), so you will need to split the two clauses into separate properties (one with $a< x1 <b$, one with $c < x1 <d$).\n",
"\n",
"This is partly due to implementation constraints, and on such a simple problem it should not be a limitation. Note, however, that the sets of properties expressible with abstract interpretation and with SAT/SMT solving are different.\n",
"\n",
"Here is a recap about the tools we will be using:\n",
"\n",
"| | Z3 | Marabou | PyRAT |\n",
"|---|---|---|---|\n",
"| input format | SMTLIB | Specific | Specific |\n",
"| technology | SMT | SMT / overapproximation | abstract interpretation |\n",
"| specialized for neural networks | no | yes | yes |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 1: Z3, the classical SMT solver\n",
"\n",
"As mentioned, Z3 is not made specifically for neural networks: we simply encode the network as a classical problem for the solver to handle. We thus first need to translate the network into the SMT format before launching the tool. This is done in `call_isaieh` using ISAIEH, an open-source tool developed at CEA. The `launch_z3` function then calls the solver on the translated network concatenated with the property."
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"\n",
"def call_isaieh(fpath):\n",
"    \"\"\"Convert the ONNX network at fpath to an SMT formula describing\n",
"    its control flow. The output file is named after fpath\n",
"    (e.g. network.onnx -> network_QF_NRA.smt2).\"\"\"\n",
" !./bin/isaieh.exe -theory=QF_NRA $fpath"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/bash: ./bin/isaieh.exe: Permission denied\r\n"
]
}
],
"source": [
"call_isaieh('network.onnx')"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"def launch_z3(control_flow, constraints, verbose=True):\n",
" \"\"\"Launch z3 on the SMT problem\n",
" composed of the concatenation of a control flow and constraints, both written in SMTLIB2\"\"\"\n",
" !cat $control_flow $constraints > z3_input\n",
" \n",
" t = time.perf_counter()\n",
" output = !z3 z3_input\n",
" output = \"\\n\".join(output)\n",
" if verbose:\n",
" print(output)\n",
" \n",
" return \"unsat\" in output, time.perf_counter() - t "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To launch Z3, provide the filepath of the network's control flow in SMTLIB as well as the filepath of the\n",
"formula you generated. \n",
"\n",
"A **SAT** result will be followed by an instantiation of the inputs and outputs that satisfies the negation of your property, i.e. a counter-example. \n",
"\n",
"An **UNSAT** result means that no such counter-example exists: your property holds."
]
},
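{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both `launch_z3` and `launch_marabou` decide the verdict by searching for `unsat` in the solver output, so a timeout or a crash is reported in the same way as a SAT answer. A minimal sketch of a stricter check is given below; `solver_verdict` is a hypothetical helper, and it assumes that the verdict is printed alone on its own line (the case for Z3; to be checked for your Marabou build):\n",
"\n",
"```python\n",
"def solver_verdict(output: str) -> str:\n",
"    # Compare whole lines: the substring 'sat' also occurs inside 'unsat'\n",
"    for line in reversed(output.lower().splitlines()):\n",
"        word = line.strip()\n",
"        if word in ('sat', 'unsat', 'unknown'):\n",
"            return word\n",
"    # No verdict found (timeout, crash, ...): neither a proof nor a counter-example\n",
"    return 'unknown'\n",
"\n",
"print(solver_verdict('unsat'))             # unsat\n",
"print(solver_verdict('sat\\n(model ...)'))  # sat\n",
"print(solver_verdict('timeout'))           # unknown\n",
"```"
]
},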
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cat: network_QF_NRA.smt2: No such file or directory\n",
"/bin/bash: z3: command not found\n"
]
}
],
"source": [
"res, t_z3 = launch_z3('network_QF_NRA.smt2','formula.smt2')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2: Marabou the simplex-based solver\n",
"\n",
"As opposed to Z3, Marabou is tailor-made for neural networks. It uses its own property format and can only read networks in the NNet format. We provide `network.nnet`, a copy of the ONNX network converted to NNet with a converter."
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"def launch_marabou(network, constraints, verbose=True):\n",
" \"\"\"Launch marabou on the verification problem:\n",
" network is a .nnet description of the network (provided here\n",
" for simplicity) and constraints is a property file you wrote\"\"\"\n",
" t = time.perf_counter()\n",
" output = !./bin/marabou.elf --timeout=100 $network $constraints\n",
" output = \"\\n\".join(output)\n",
" if verbose:\n",
" print(output)\n",
" return \"unsat\" in output.lower(), time.perf_counter() - t "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly for Marabou, provide the filepath of the model and of the property you generated in the Marabou format. The output of Marabou is more verbose than Z3's; the important part lies at the end: SAT or UNSAT. "
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/bash: ./bin/marabou.elf: Permission denied\n"
]
}
],
"source": [
"res, t_marabou = launch_marabou(\"network.nnet\", 'formula.marabou')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 3: PyRAT with Abstract Interpretation\n",
"\n",
"The PyRAT API directly provides a function to launch the analysis, `launch_pyrat`. Remember that a _negative value_ means\n",
"no alarm is issued, while a _positive value_ means that an alarm is issued. \n",
"\n",
"PyRAT works directly on the ONNX network, and its property format is similar to Marabou's, the only difference being that you write the property to prove instead of its negation.\n",
"- **True** means the property holds,\n",
"- **False** means it does not.\n",
"\n",
"PyRAT will also output the bounds reached at the end of the analysis."
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output bounds:\n",
" [-6.0188885]\n",
"[-5.915209]\n",
"Result = True, Time = 0.00 s\n"
]
}
],
"source": [
"from pyrat_api import launch_pyrat\n",
"\n",
"res, t_pyrat = launch_pyrat(\"network.onnx\", \"formula.txt\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With PyRAT you can also use other abstract domains to get more precise results; the available domains are \"poly\", \"zono\" and \"symbox\". While they increase precision, they also increase computation time. Nevertheless, outside of simple settings like this one, intervals alone are often too imprecise to conclude.\n",
"\n",
"An additional argument, `split_timeout`, can also help to increase the precision of the analysis. Try it now; we will look at it in more detail in the next section."
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result = True, Time = 0.01 s, Safe space = 100.00 %, number of analysis = 1\n"
]
}
],
"source": [
"res, t_pyrat = launch_pyrat(\"network.onnx\", \"formula.txt\", domains=[\"zono\"], split_timeout=10)"
]
},
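{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compare the precision/time trade-off of the domains listed above, one can loop over them on the same property. This is only a sketch reusing the `launch_pyrat` call shown in the previous cells; the exact set of accepted domain names may depend on your PyRAT version:\n",
"\n",
"```python\n",
"from pyrat_api import launch_pyrat\n",
"\n",
"# Same network and property, one abstract domain at a time, without split_timeout\n",
"for domain in ['poly', 'zono', 'symbox']:\n",
"    res, elapsed = launch_pyrat('network.onnx', 'formula.txt', domains=[domain])\n",
"    print(f'{domain}: result = {res}, time = {elapsed:.3f} s')\n",
"```"
]
},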
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Moving to bigger properties\n",
"\n",
"You were able to launch solvers to verify this simple property. We can already notice that Z3 took significantly more time than the others to return a result: this comes from the suboptimal encoding of the problem as well as the lack of heuristics tailored to neural network verification. Using Z3 on more complex properties will likely hang your session; don't hesitate to interrupt the cell's execution if it takes too much time.\n",
"\n",
"Even so, all three tools managed to prove the simple property we created. But this property was not the one we initially mentioned. We should aim to answer the following questions:\n",
"- With the setting $\\alpha = 0.5$, $\\beta = 0.25$, $\\delta_1 = 0.3$, $\\delta_2 = 0.7$, can we prove our three initial properties on this network?\n",
"- If not, for which values of $\\alpha, \\beta, \\delta_1$ and $\\delta_2$ do they hold?\n",
"\n",
"Initial properties of the zones:\n",
"1. a **“safe”** zone: when B is in this zone ($||\\vec{d}|| > \\delta_2$), it is not considered a threat and no ALARM is issued.\n",
"2. a **“suspicious”** zone: when B is in this zone, if $||\\vec{v}|| > \\alpha$ and $\\theta < \\beta$\n",
"   then an ALARM should be issued. Else, no ALARM is issued.\n",
"3. a **“danger”** zone: when B is in this zone, an ALARM is issued no matter what. When $||\\vec{d}|| < \\delta_1$, B is in the danger zone.\n",
"\n",
"**Write these properties and try proving them with the different tools**. "
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x0 >= 0.7; x0 < 0.3; x1 >= 0.4; x1 < 1; x2 < 0.25; x2 >= 0; \n",
"Wrote pyrat formula in file safe.txt\n",
"Wrote pyrat formula in file suspicious.txt\n",
"Wrote pyrat formula in file danger.txt\n"
]
}
],
"source": [
"from formula_lang import * \n",
"\n",
"distance = Var('x0')\n",
"speed = Var('x1')\n",
"angle = Var('x2')\n",
"\n",
"output = Var('y0')\n",
"\n",
"one = Real(1)\n",
"delta_1 = Real(0.2)\n",
"delta_2 = Real(0.78)\n",
"alpha = Real(1)\n",
"beta = Real(0.01)\n",
"zero = Real(0)\n",
"\n",
"# Safe\n",
"constrs = []\n",
"constrs.append(Constr(distance, '>=', delta_2))\n",
"constrs.append(Constr(distance, '<', one))\n",
"constrs.append(Constr(speed, '>=', zero))\n",
"constrs.append(Constr(speed, '<', one))\n",
"constrs.append(Constr(angle, '>=', zero))\n",
"constrs.append(Constr(angle, '<', one))\n",
"constrs.append(Constr(output, '<', zero))\n",
"\n",
"formula_safe = Formula()\n",
"for c in constrs:\n",
" formula_safe.add(c)\n",
" \n",
"# Suspicious\n",
"constrs = []\n",
"constrs.append(Constr(distance, '>=', zero))\n",
"constrs.append(Constr(distance, '<', one))\n",
"constrs.append(Constr(speed, '>=', alpha))\n",
"constrs.append(Constr(speed, '<', one))\n",
"constrs.append(Constr(angle, '>=', zero))\n",
"constrs.append(Constr(angle, '<', beta))\n",
"constrs.append(Constr(output, '>=', zero))\n",
"\n",
"formula_sus = Formula()\n",
"for c in constrs:\n",
" formula_sus.add(c)\n",
" \n",
" \n",
"# Danger\n",
"constrs = []\n",
"constrs.append(Constr(distance, '>=', zero))\n",
"constrs.append(Constr(distance, '<', delta_1))\n",
"constrs.append(Constr(speed, '>=', zero))\n",
"constrs.append(Constr(speed, '<', one))\n",
"constrs.append(Constr(angle, '>=', zero))\n",
"constrs.append(Constr(angle, '<', one))\n",
"constrs.append(Constr(output, '>=', zero))\n",
"\n",
"formula_danger = Formula()\n",
"for c in constrs:\n",
" formula_danger.add(c)\n",
" \n",
"print(formula_safe)\n",
"print(formula_sus)\n",
"print(formula_danger)\n",
" \n",
"formula_safe.write_pyrat(\"safe.txt\")\n",
"formula_sus.write_pyrat(\"suspicious.txt\")\n",
"formula_danger.write_pyrat(\"danger.txt\")"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result = True, Time = 0.01 s, Safe space = 100.00 %, number of analysis = 1\n",
"Result = False, Time = 0.04 s, Safe space = 75.00 %, number of analysis = 7\n",
"Result = True, Time = 0.07 s, Safe space = 100.00 %, number of analysis = 11\n"
]
}
],
"source": [
"res, t_pyrat = launch_pyrat(\"network.onnx\", \"safe.txt\", domains=[\"zono\"], split_timeout=10)\n",
"res, t_pyrat = launch_pyrat(\"network.onnx\", \"suspicious.txt\", domains=[\"zono\"], split_timeout=10)\n",
"res, t_pyrat = launch_pyrat(\"network.onnx\", \"danger.txt\", domains=[\"zono\"], split_timeout=10)\n",
"#res, t_marabou = launch_marabou(\"network.nnet\", 'formula.marabou')\n",
"#res, t_z3 = launch_z3('network_QF_NRA.smt2','formula.smt2')"
]
},
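{
"cell_type": "markdown",
"metadata": {},
"source": [
"The *else* part of the suspicious-zone property (no ALARM when $||\\vec{v}|| \\leq \\alpha$ or $\\theta \\geq \\beta$) is a disjunction, which, as mentioned above, has to be split into separate conjunctive properties. A minimal sketch with `formula_lang`, assuming the setting $\\alpha = 0.5$, $\\beta = 0.25$, $\\delta_1 = 0.3$, $\\delta_2 = 0.7$ and assuming that the suspicious zone is $\\delta_1 \\leq ||\\vec{d}|| < \\delta_2$ (`no_alarm_formula` and the file names are hypothetical):\n",
"\n",
"```python\n",
"from formula_lang import Var, Real, Constr, Formula\n",
"\n",
"distance, speed, angle = Var('x0'), Var('x1'), Var('x2')\n",
"output = Var('y0')\n",
"zero, one = Real(0), Real(1)\n",
"alpha, beta = Real(0.5), Real(0.25)\n",
"delta_1, delta_2 = Real(0.3), Real(0.7)\n",
"\n",
"def no_alarm_formula(speed_low, speed_high, angle_low, angle_high):\n",
"    # Suspicious zone restricted to a speed/angle box, where no alarm is expected\n",
"    f = Formula()\n",
"    for c in [Constr(distance, '>=', delta_1), Constr(distance, '<', delta_2),\n",
"              Constr(speed, '>=', speed_low), Constr(speed, '<', speed_high),\n",
"              Constr(angle, '>=', angle_low), Constr(angle, '<', angle_high),\n",
"              Constr(output, '<', zero)]:\n",
"        f.add(c)\n",
"    return f\n",
"\n",
"# speed <= alpha OR angle >= beta  ->  two properties, one per disjunct\n",
"# (only '>=' and '<' are available, so the single point speed = alpha is left out)\n",
"no_alarm_formula(zero, alpha, zero, one).write_pyrat('suspicious_else_speed.txt')\n",
"no_alarm_formula(zero, one, beta, one).write_pyrat('suspicious_else_angle.txt')\n",
"```\n",
"\n",
"Each file can then be checked with `launch_pyrat(\"network.onnx\", ...)`; the *else* part holds only if both properties hold."
]
},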
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*****\n",
"\n",
"## Part 3: A real problem ACAS-Xu\n",
"\n",
"Let's tackle a more complex network from the public benchmark of the Airborne Collision Avoidance System for Unmanned aircraft (ACAS-Xu). For an introduction to the ACAS-Xu benchmark, see _An Introduction to ACAS Xu and the Challenges Ahead_, Manfredi G., Jestin Y.\n",
"\n",
"Formerly, programs such as ACAS were implemented as lookup tables. The neural network approach was partly introduced to reduce the memory requirements of the original 2GB tables down to 3MB. As neural networks became more and more efficient, manufacturers considered using them as a possible replacement. Since we are dealing with critical systems, software safety is a major concern.\n",
"\n",
"In _Reluplex: An Efficient SMT Solver for Verifying Deep Neural Networks_, Katz et al. provided a neural network implementation of an ACAS-Xu system, as well as tooling to formally verify several safety properties. It was then adopted as a common benchmark in the literature.\n",
"\n",
"The problem is similar to our previous one: a neural network needs to give direction advisories to a drone when an intruder is detected nearby. The advisory can be \"Clear of Conflict\" (COC), meaning there is no risk of collision, or \"Right\", \"Strong Right\", \"Left\", \"Strong Left\" for the evasion direction. The advisory with the minimum score is the decision taken.\n",
"\n",
"![acas](imgs/acas.png)\n",
"\n",
"The network takes 5 inputs:\n",
"- ρ (m): Distance from ownship to intruder.\n",
"- θ (rad): Angle to intruder relative to ownship heading\n",
"direction.\n",
"- ψ (rad): Heading angle of intruder relative to ownship\n",
"heading direction.\n",
"- $v_{own}$ (m/s): Speed of ownship.\n",
"- $v_{int}$ (m/s): Speed of intruder.\n",
"\n",
"We want to prove one property on this network: \n",
"\n",
"**Property 1.**\n",
"- Description: If the intruder is distant and is significantly slower than the\n",
"ownship, the score of a COC advisory will always be below a certain fixed\n",
"threshold.\n",
"- Input constraints: ρ ≥ 55947.691, $v_{own} ≥ 1145$, $v_{int} ≤ 60$.\n",
"- Desired output property: the score for COC is at most 1500.\n",
"\n",
"The formulas for this property are already written. Normalised values will be given below."
]
},
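{
"cell_type": "markdown",
"metadata": {},
"source": [
"The normalised bounds used in the rest of this part can be recovered from the raw constraints of Property 1. A minimal sketch, assuming that `acas.nnet` uses the normalisation constants (means and ranges) stored in the ACAS-Xu `.nnet` files of Katz et al., where the last entry applies to the output scores (`normalise` is a hypothetical helper):\n",
"\n",
"```python\n",
"# Assumed normalisation constants of the ACAS-Xu .nnet files:\n",
"# inputs (rho, theta, psi, v_own, v_int), then the output scores\n",
"means = [19791.091, 0.0, 0.0, 650.0, 600.0, 7.5188840201005975]\n",
"ranges = [60261.0, 6.28318530718, 6.28318530718, 1100.0, 1200.0, 373.94992]\n",
"\n",
"def normalise(value, index):\n",
"    return (value - means[index]) / ranges[index]\n",
"\n",
"print(normalise(55947.691, 0))  # rho >= 55947.691  ->  x0 >= ~0.6\n",
"print(normalise(1145, 3))       # v_own >= 1145     ->  x3 >= ~0.45\n",
"print(normalise(60, 4))         # v_int <= 60       ->  x4 <= ~-0.45\n",
"print(normalise(1500, 5))       # COC score <= 1500 ->  y0 <= ~3.9911256\n",
"```\n",
"\n",
"The other bounds of the input box (e.g. 0.6798577687 for $\\rho$) are the normalised limits of each input's operating range."
]
},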
{
"cell_type": "code",
"execution_count": 156,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wrote pyrat formula in file formula_p1.txt\n",
"Wrote marabou formula in file formula_p1.marabou\n",
"Output bounds:\n",
" [-10710.81020574 -9110.01167952 -12047.17947303 -2730.80270361\n",
" -13194.04888676]\n",
"[ 5272.51746506 5429.67755395 4817.41682782 10864.69018591\n",
" 5844.88081043]\n",
"Result = Unknown, Time = 0.01 s\n"
]
}
],
"source": [
"from formula_lang import formula_p1\n",
"from pyrat_api import launch_pyrat\n",
"\n",
"formula_p1().write_pyrat(\"formula_p1.txt\")\n",
"formula_p1().write_marabou(\"formula_p1.marabou\")\n",
"\n",
"res, t_pyrat = launch_pyrat(\"acas.nnet\", \"formula_p1.txt\", domains=[\"zono\"])\n",
"#res, t_marabou = launch_marabou(\"acas.nnet\", 'formula_p1.marabou') # does not finish"
]
},
{
"cell_type": "code",
"execution_count": 158,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result = True, Time = 2.97 s, Safe space = 100.00 %, number of analysis = 281\n"
]
}
],
"source": [
"# splitting the inputs\n",
"res, t_pyrat = launch_pyrat(\"acas.nnet\", \"formula_p1.txt\", domains=[\"zono\"], split_timeout=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see here that PyRAT performs 281 analyses on this property in total. Indeed, after failing to prove the property at the beginning, it divides the input space and performs further analyses until it proves the property. Meanwhile, Marabou takes too long to prove the property, so we might want to speed it up.\n",
"\n",
"In this part, we will reproduce by hand what PyRAT's `split_timeout` option does: we will divide the input space into smaller parts and perform the analysis on each of them.\n",
"\n",
"### Step 1: Divide a formula\n",
"\n",
"We will first look at how to divide the input space for our problem. Formulas are mostly textual, so they are not easy to divide directly. We will thus come back to our Interval class, which we can easily divide, and then transform each Interval into constraints that we can add to a Formula.\n",
"\n",
"**Write the following functions.** `create_formula_p1` is already written and should properly create a formula if `interval_to_constraint` is correctly implemented."
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def divide_interval(x: Interval) -> (Interval, Interval):\n",
" if isinstance(x, Interval):\n",
" return Interval(x.lower, (x.upper+x.lower)/2), Interval((x.upper+x.lower)/2, x.upper)\n",
" raise NotImplementedError\n",
"\n",
"assert divide_interval(Interval(0, 1)) == (Interval(0, 0.5), Interval(0.5, 1))\n",
"assert divide_interval(Interval(-5, 1)) == (Interval(-5, -2), Interval(-2, 1))"
]
},
{
"cell_type": "code",
"execution_count": 171,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x0 >= 0.6; x0 < 0.6798577687; x1 >= -0.5; x1 < 0.5; x2 >= -0.5; x2 < 0.5; x3 >= 0.45; x3 < 0.5; x4 >= -0.5; x4 < -0.45; y0 < 3.9911256459; \n"
]
}
],
"source": [
"from typing import List\n",
"\n",
"def interval_to_constraint(x: Interval, name: str) -> (Constr, Constr):\n",
"    if isinstance(x, Interval):\n",
"        return Constr(Var(name), '>=', Real(x.lower)), Constr(Var(name), '<', Real(x.upper))\n",
"    raise NotImplementedError\n",
"\n",
"def create_formula_p1(inputs: List[Interval]) -> Formula:\n",
"    output = Var('y0')\n",
"\n",
"    constrs = []\n",
"    for i in range(len(inputs)):\n",
"        constrs.extend(interval_to_constraint(inputs[i], f\"x{i}\")) # for each input interval we create and add the constraint\n",
"    \n",
"    constrs.append(Constr(output, '<', Real(3.9911256459))) # constraint on the output\n",
"\n",
"    formula = Formula()\n",
"    for c in constrs:\n",
"        formula.add(c)\n",
"\n",
"    return formula\n",
"\n",
"# normalised input box of Property 1 (v_int <= 60 gives x4 < -0.45)\n",
"initial = [Interval(0.6, 0.6798577687), Interval(-0.5, 0.5), Interval(-0.5, 0.5), Interval(0.45, 0.5), Interval(-0.5, -0.45)]\n",
"print(create_formula_p1(initial))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2: Iteration algorithm\n",
"\n",
"Now that we can generate a formula from intervals and divide intervals, we can start working on an algorithm that divides the input space for the analysis. The general idea of the algorithm is:\n",
"\n",
"1. Initial: 1 Interval per input, as defined by the property\n",
"2. Create a formula from the intervals\n",
"3. Run an analysis with either PyRAT or Marabou (Marabou has a timeout of 100 seconds, which can be changed in `launch_marabou`)\n",
"4. If the property is verified, stop. \n",
"   Otherwise, divide one interval in 2 and go back to step 2 for both subspaces created\n",
"5. If all subspaces are proven True/unsat, we can conclude that the initial space is as well.\n",
"\n",
"`launch_pyrat` and `launch_marabou` both return a tuple `(bool, float)`: the boolean is `True` when the property holds, and the float is the time taken by the analysis. Here, do not let PyRAT do any splitting on its own: use `split_timeout=0` or omit the argument."
]
},
{
"cell_type": "code",
"execution_count": 182,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-13703.76840556 -11650.49647462 -15399.48471075 -3491.64369193\n",
" -16873.93784018]\n",
"[ 6741.31879213 6948.28653093 6161.77284022 13889.18982366\n",
" 7474.49709356]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-3550.55339166 -3023.83749887 -4002.8480418 -902.79341163\n",
" -4383.57122813]\n",
"[1745.18966127 1794.98362025 1593.07387949 3605.36711416 1933.38484528]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-55.06820859 -46.94232084 -65.63917304 -19.87184717 -73.61251578]\n",
"[37.9098227 37.86804778 35.40315467 70.01284302 43.81451971]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.02121775 -0.01879782 -0.01952425 -0.01893638 -0.0185982 ]\n",
"[-0.01423817 -0.01179154 -0.00823094 -0.01761423 -0.00731985]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.0207541 -0.01877377 -0.01951269 -0.01891985 -0.01858325]\n",
"[-0.02016294 -0.01871954 -0.01948663 -0.01888259 -0.01854955]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-352.70570723 -303.42657267 -401.7464177 -90.58138666 -436.26610886]\n",
"[172.19825764 173.99488784 157.01285023 362.83767465 192.71922046]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.15568065 -0.1716002 -0.20014387 -0.699908 -0.27070714]\n",
"[0.92424944 1.04669573 1.47032404 0.97590631 1.82164312]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [ -7.47190445 -6.81427319 -10.90876595 -6.29717939 -14.23650304]\n",
"[12.45968295 12.90411493 12.33305946 16.50971771 15.35373532]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.02050728 -0.01878764 -0.01951936 -0.01892939 -0.01859188]\n",
"[-0.0042468 0.00676165 0.05240967 0.00147856 0.06137028]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.05395529 -0.07882239 -0.0844974 -0.09567312 -0.23208993]\n",
"[0.10340051 0.15111863 0.11976172 0.14357011 0.18966976]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-3695.87115685 -3150.16992906 -4167.56277234 -945.01871031\n",
" -4559.71282389]\n",
"[1821.48322914 1870.45224624 1662.87037422 3759.04988596 2018.72527643]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-282.51928827 -241.91368502 -320.69790512 -73.04203202 -349.31167125]\n",
"[138.68520869 140.0404299 127.11648591 288.24411521 156.15636524]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-1.60438756 -1.94363294 -3.4275154 -3.51762934 -4.9584253 ]\n",
"[6.10297656 7.01423348 7.20661815 8.34524076 8.29646485]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.02100108 -0.01882828 -0.01953889 -0.01895731 -0.01863867]\n",
"[-0.0172831 -0.01226754 -0.01223274 -0.01425342 -0.0132472 ]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.02078877 -0.01879703 -0.01952387 -0.01893584 -0.01859771]\n",
"[-0.01966997 -0.01630435 -0.01706997 -0.01716751 -0.01701337]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.02383511 -0.01882766 -0.03147746 -0.05449972 -0.01881987]\n",
"[0.08913946 0.19379825 0.13595365 0.15602469 0.15905256]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-1.95314583 -2.29833054 -4.66155779 -4.11012113 -6.49521951]\n",
"[ 6.57163797 8.37302401 7.80644603 11.1525727 9.50623178]\n",
"Result = Unknown, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.02041463 -0.01878185 -0.01951658 -0.01892541 -0.01858828]\n",
"[-0.02007477 -0.01875068 -0.01950159 -0.01890399 -0.0185689 ]\n",
"Result = True, Time = 0.01 s\n",
"Wrote pyrat formula in file formula.txt\n",
"Output bounds:\n",
" [-0.02037451 -0.01878399 -0.0195176 -0.01892688 -0.0185896 ]\n",
"[-0.02005151 -0.01875436 -0.01950336 -0.01890652 -0.01857119]\n",
"Result = True, Time = 0.01 s\n"
]
}
],
"source": [
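"def analyse(inputs: List[Interval], depth: int = 0, max_depth: int = 20) -> bool:\n",
"    formula = create_formula_p1(inputs)\n",
"    formula.write_pyrat(\"formula.txt\")\n",
"    res, t_pyrat = launch_pyrat(\"acas.nnet\", \"formula.txt\", domains=[\"zono\"])\n",
"    \n",
"    if res == True:\n",
"        return True  # property proven on this subspace\n",
"    if depth >= max_depth:\n",
"        return False  # give up: not proven at the maximal splitting depth\n",
"    \n",
"    # split only the widest interval, so that the two halves cover the whole subspace\n",
"    k = max(range(len(inputs)), key=lambda i: inputs[i].upper - inputs[i].lower)\n",
"    low, high = divide_interval(inputs[k])\n",
"    i1 = inputs[:k] + [low] + inputs[k + 1:]\n",
"    i2 = inputs[:k] + [high] + inputs[k + 1:]\n",
"    \n",
"    # the subspace is proven only if both halves are proven\n",
"    return analyse(i1, depth + 1, max_depth) and analyse(i2, depth + 1, max_depth)\n",
"    \n",
"initial = [Interval(0.6, 0.6798577687), Interval(-0.5, 0.5), Interval(-0.5, 0.5), Interval(0.45, 0.5), Interval(-0.5, -0.45)]\n",
"analyse(initial)"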
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can print some statistics, such as the number of analyses you run or the time taken. **Can you do fewer analyses than PyRAT?**\n",
"\n",
"### Heuristics to go further\n",
"\n",
"How could we match or beat PyRAT's number of analyses?\n",
"- Can we select the interval to split more carefully?\n",
"- Can we skip some of the preliminary analyses because we know they won't succeed?"
]
},
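{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible starting point for these questions is sketched below: a breadth-first (worklist) version of the splitting loop that counts the analyses and their total time, and splits only the widest input interval (measured in normalised units). `analyse_worklist` and `max_analyses` are hypothetical names; the sketch reuses `create_formula_p1`, `divide_interval`, `launch_pyrat` and `initial` from the cells above:\n",
"\n",
"```python\n",
"from collections import deque\n",
"\n",
"def analyse_worklist(initial, max_analyses=1000):\n",
"    # Returns (proven, number of analyses, total analysis time)\n",
"    queue = deque([initial])\n",
"    n_analyses, total_time = 0, 0.0\n",
"    while queue:\n",
"        if n_analyses >= max_analyses:\n",
"            return None, n_analyses, total_time  # budget exhausted: inconclusive\n",
"        inputs = queue.popleft()\n",
"        create_formula_p1(inputs).write_pyrat('formula.txt')\n",
"        res, elapsed = launch_pyrat('acas.nnet', 'formula.txt', domains=['zono'])\n",
"        n_analyses += 1\n",
"        total_time += elapsed\n",
"        if res == True:\n",
"            continue  # this subspace is proven\n",
"        # heuristic: split the widest input interval into two halves\n",
"        k = max(range(len(inputs)), key=lambda i: inputs[i].upper - inputs[i].lower)\n",
"        low, high = divide_interval(inputs[k])\n",
"        queue.append(inputs[:k] + [low] + inputs[k + 1:])\n",
"        queue.append(inputs[:k] + [high] + inputs[k + 1:])\n",
"    return True, n_analyses, total_time\n",
"\n",
"proven, n, t = analyse_worklist(initial)\n",
"print(f'proven = {proven}, analyses = {n}, time = {t:.2f} s')\n",
"```\n",
"\n",
"Other criteria for choosing the interval to split, for instance how much each input influences the output bounds, may reduce the number of analyses further."
]
},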
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*****\n",
"\n",
"## Part 4: Choosing a network for image classification\n",
"\n",
"For this section we will focus on an open-source image dataset for classification. The images we will use are a subset of the Fashion-MNIST dataset ([available here](https://www.kaggle.com/datasets/zalando-research/fashionmnist)). This dataset was developed to replace the overused MNIST dataset and has the same characteristics:\n",
"- 28x28 grayscale images\n",
"- 10 output classes: T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot\n",
"\n",
"![fmnist](imgs/fashion_mnist.png)\n",
"\n",
"We will use a subset of 50 images for the purpose of this practical session.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal for this part is to decide which model would be best suited for our needs. Let's suppose we are in a critical system where picking the wrong class for a garment might lead to damage (ecological, financial, loss of clients, ...). \n",
"\n",
"We propose 5 models trained with different methods:\n",
"1. **Baseline model**, normal training\n",
"2. **Adversarial model**, adversarial training\n",
"3. **Pruned model**, normal training + pruning\n",
"4. **Certified model**, certified training\n",
"5. **Pruned certified model**, certified training + pruning\n",
"\n",
"We must decide on a model to use. The accuracy of the models has already been computed on the whole test set (more than 50 images). This is already a first selection criterion, as some training methods lead to lower accuracy.\n",
"\n",
"|Model | Accuracy|\n",
"|--- | ---|\n",
"|Baseline | 90.50%|\n",
"|Adversarial | 79%|\n",
"|Pruned | 89%|\n",
"|Certified | 72.30%|\n",
"|Pruned Certified | 73.20%|\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 1: Local robustness \n",
"\n",
"We will look at the local robustness of the five models around the 50 images of our sub-dataset. All models and images are available in the `fmnist` folder. We already created two utility functions, one to read the images and one to launch PyRAT on an image and a network: `read_images` returns a list of `(image_i, label_i)`, while `local_robustness` returns the robustness of a network around an image for a given perturbation, together with the analysis time.\n",
"\n",
"An example is given below with a sneaker image (label 7) and a perturbation of 1/255, i.e. each pixel may vary by up to one grey level."
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image shape: (28, 28), label: 7\n",
"Robust: True, Time: 0.39942956200047774\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdTUlEQVR4nO3df2yV5f3/8dcp0FN+tKeU0p5WChZQcPJjG5PaqQxDQ8HNCLJEnX/A4nS4YqZMXVim6FzSjU/ijAvTJUtkZoLOZEAkC4kWW/ajxYAwwpwNxU6q9IeinEOLLaW9vn/w9cwjP6+b077b8nwkV0LPfb97v3v15rx6n3P3asg55wQAQD9Ls24AAHB5IoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgYrh1A1/W29urI0eOKDMzU6FQyLodAIAn55yOHz+uwsJCpaWd+zpnwAXQkSNHVFRUZN0GAOASNTU1acKECefcPuBegsvMzLRuAQCQAhd6Pu+zAFq/fr2uvPJKZWRkqKSkRG+99dZF1fGyGwAMDRd6Pu+TAHrllVe0evVqrV27Vm+//bZmz56t8vJytbW19cXhAACDkesDc+fOdRUVFYmPe3p6XGFhoausrLxgbSwWc5IYDAaDMchHLBY77/N9yq+ATp48qT179qisrCzxWFpamsrKylRbW3vG/l1dXYrH40kDADD0pTyAPv74Y/X09Cg/Pz/p8fz8fLW0tJyxf2VlpSKRSGJwBxwAXB7M74Jbs2aNYrFYYjQ1NVm3BADoByn/PaDc3FwNGzZMra2tSY+3trYqGo2esX84HFY4HE51GwCAAS7lV0Dp6emaM2eOqqqqEo/19vaqqqpKpaWlqT4cAGCQ6pOVEFavXq3ly5frG9/4hubOnatnnnlGHR0d+v73v98XhwMADEJ9EkB33HGHPvroIz3++ONqaWnRV7/6VW3fvv2MGxMAAJevkHPOWTfxRfF4XJFIxLoNAMAlisViysrKOud287vgAACXJwIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhIeQA98cQTCoVCSWP69OmpPgwAYJAb3hef9Nprr9Ubb7zxv4MM75PDAAAGsT5JhuHDhysajfbFpwYADBF98h7QwYMHVVhYqMmTJ+vuu+/W4cOHz7lvV1eX4vF40gAADH0pD6CSkhJt2LBB27dv13PPPafGxkbddNNNOn78+Fn3r6ysVCQSSYyioqJUtwQAGIBCzjnXlwc4duyYJk2apKefflr33HPPGdu7urrU1dWV+DgejxNCADAExGIxZWVlnXN7n98dkJ2drauvvloNDQ1n3R4OhxUOh/u6DQDAANPnvwfU3t6uQ4cOqaCgoK8PBQAYRFIeQA8//LBqamr03//+V//85z+1dOlSDRs2THfddVeqDwUAGMRS/hLcBx98oLvuuktHjx7V+PHjdeONN6qurk7jx49P9aEAAINYn9+E4CsejysSiVi3AQC4RBe6CYG14AAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAICJ4dYNAIPd1KlTvW
syMjK8aw4cOOBd05+GD/d/Ojl16lQfdHKmtLRgP2v39vamuBN8EVdAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATLAYKXCJPv30U++aN99807vmgQce8K6pqanxrulP4XDYu6arq6sPOoEFroAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYDFS4BJ1d3d714wdO9a75gc/+IF3TdCFO+vq6rxrTp065V3T29vrXZOW5v9zc5DjDHRFRUWB6pqbm71rgnxvLwZXQAAAEwQQAMCEdwDt3LlTt956qwoLCxUKhbRly5ak7c45Pf744yooKNDIkSNVVlamgwcPpqpfAMAQ4R1AHR0dmj17ttavX3/W7evWrdOzzz6r559/Xrt27dLo0aNVXl6uzs7OS24WADB0eN+EsHjxYi1evPis25xzeuaZZ/Tzn/9ct912myTpxRdfVH5+vrZs2aI777zz0roFAAwZKX0PqLGxUS0tLSorK0s8FolEVFJSotra2rPWdHV1KR6PJw0AwNCX0gBqaWmRJOXn5yc9np+fn9j2ZZWVlYpEIokR9NZCAMDgYn4X3Jo1axSLxRKjqanJuiUAQD9IaQBFo1FJUmtra9Ljra2tiW1fFg6HlZWVlTQAAENfSgOouLhY0WhUVVVVicfi8bh27dql0tLSVB4KADDIed8F197eroaGhsTHjY2N2rdvn3JycjRx4kQ9+OCD+uUvf6mrrrpKxcXFeuyxx1RYWKglS5aksm8AwCDnHUC7d+/WzTffnPh49erVkqTly5drw4YNevTRR9XR0aH77rtPx44d04033qjt27crIyMjdV0DAAa9kHPOWTfxRfF4XJFIxLoNDHLXX399oLry8nLvmn/961/eNStWrOiX4xw9etS7RpKqq6u9a/bv3x/oWJAmT57sXfPDH/4w0LE2bdrkXbNv375Ax4rFYud9X9/8LjgAwOWJAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGDC+88xYGAL8hdl29vbAx1r2bJl/XKsICtU792717tGktLS/H8mmzNnjnfNK6+84l1z6tQp75qg5s+f710zd+5c75o//OEP3jUDXZBzqKyszLtm1KhR3jWSNHPmTO+aoKthXwhXQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEyEnHPOuokvisfjikQi1m0MWk899ZR3zXvvvRfoWBMmTPCuOXjwoHdNc3Ozd01mZqZ3jSR1dHR41+Tm5vZLTTQa9a658sorvWsk6cCBA941Y8aM8a4J8n99w4YN3jVBF9MMh8PeNUVFRd41t9xyi3dNLBbzrpGkEydOeNfU1dV57d/b26sPP/xQsVjsvAskcwUEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADAxIBdjDQUCikUCl103dixY72PFWShQen0Qnv9UfPNb37Tu6a9vd27ZtSoUd41kvTJJ5941wRZqDE9Pd27ZsSIEd41UrDzKMj3dty4cd41Qb6mjIwM7xpJmjRpknfN+++/711z6NAh75qvfOUr3jVBFuCUpHfeece7pqCgwLtm6dKl3jWffvqpd40kbdu2zbvGd+HTkydP6sUXX2QxUgDAwEQAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMDEgF2MNBqNKi3t4vOxsLDQ+1jFxcXeNZLU1tbmXZObm+td4/P1fy4ej3vXRKNR7xop2OKYXV1d3jXXXHONd03QhRqHDx/uXRPkPAqyaGyQBUy7u7u9a6Rg83f99dd711
x11VXeNTt27PCuCbrg7rRp07xrpk+f7l3z8ccfe9ccOHDAu0YKdk74/l/v6OjQd77zHRYjBQAMTAQQAMCEdwDt3Ll
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from pyrat_api import read_images, local_robustness\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"images = read_images()\n",
"image_0, label_0 = images[0]\n",
"print(f\"Image shape: {image_0.shape}, label: {label_0}\")\n",
"plt.imshow(image_0, cmap='gray')\n",
"\n",
"res, elapsed = local_robustness(model_path=\"fmnist/baseline.onnx\", image=image_0, label=label_0, pert=1/255, domains=[\"zono\"])\n",
"print(f\"Robust: {res}, Time: {elapsed}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now it's your turn: write a function that analyses all the given images for a given network and a given perturbation, and returns a robustness score (the number of safe, unknown and unsafe images) together with the mean analysis time."
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
],
"source": [
"def robustness(images, model_path, pert):\n",
" raise NotImplementedError\n",
" \n",
"robustness(images, \"fmnist/baseline.onnx\", 1/255) # should return [1, 48, 1]"
]
},
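{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible way to organise the counting part of the exercise is sketched below, under stated assumptions. The hypothetical helper `robustness_score` takes an `analyse(image, label)` callable that returns a `(verdict, elapsed)` pair, where `verdict` is one of `'safe'`, `'unknown'` or `'unsafe'`. This convention belongs to the sketch and is not necessarily what `local_robustness` returns. You still need a small wrapper that turns the output of `local_robustness` into such a verdict. The helper returns the counts in the order `[safe, unknown, unsafe]` together with the mean analysis time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def robustness_score(images, analyse):\n",
"    # Hypothetical helper: analyse(image, label) must return (verdict, elapsed)\n",
"    # with verdict in {'safe', 'unknown', 'unsafe'}\n",
"    counts = {'safe': 0, 'unknown': 0, 'unsafe': 0}\n",
"    total_time = 0.0\n",
"    n = 0\n",
"    for image, label in images:\n",
"        verdict, elapsed = analyse(image, label)\n",
"        counts[verdict] += 1\n",
"        total_time += elapsed\n",
"        n += 1\n",
"    # Mean analysis time per image (0.0 for an empty set of images)\n",
"    mean_time = total_time / n if n else 0.0\n",
"    return [counts['safe'], counts['unknown'], counts['unsafe']], mean_time\n",
"\n",
"# Possible usage, once you have written a wrapper (here called my_analyse) around local_robustness:\n",
"# score, mean_time = robustness_score(images, my_analyse)"
]
},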
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With these results, you can now plot how the robustness score evolves as a function of the perturbation intensity. The aim is to see whether one model is more robust than another to potential adversarial perturbations. You can use matplotlib to plot these results."
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"def plot_robustness():\n",
" raise NotImplementedError"
]
},
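{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal plotting sketch follows. It assumes you have collected one `[safe, unknown, unsafe]` triple per perturbation size for each model. The signature used here (`perts`, `scores`, and the optional `ax` and `model_name`) is hypothetical and differs from the empty `plot_robustness()` stub above. Passing the same `ax` for several models overlays their curves on a single figure, which makes the comparison between models easier to read."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_robustness(perts, scores, ax=None, model_name=''):\n",
"    # perts: perturbation sizes, e.g. [1/255, 2/255, 4/255]\n",
"    # scores: one [safe, unknown, unsafe] triple per entry of perts\n",
"    if ax is None:\n",
"        _, ax = plt.subplots()\n",
"    # One curve per verdict, labelled with the (optional) model name\n",
"    for idx, status in enumerate(['safe', 'unknown', 'unsafe']):\n",
"        counts = [score[idx] for score in scores]\n",
"        ax.plot(perts, counts, marker='o', label=f'{model_name} {status}'.strip())\n",
"    ax.set_xlabel('perturbation size')\n",
"    ax.set_ylabel('number of images')\n",
"    ax.legend()\n",
"    return ax\n",
"\n",
"# Possible usage, comparing two models on the same axes (scores_a and scores_b are yours to compute):\n",
"# ax = plot_robustness(perts, scores_a, model_name='baseline')\n",
"# plot_robustness(perts, scores_b, ax=ax, model_name='other model')"
]
},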
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2: Metamorphic testing\n",
"\n",
"We asked the engineer in charge of photographing the clothes about the conditions under which the pictures are taken. After a careful investigation, we realised that the following issues might affect our pictures:\n",
"- The luminosity of the setting varies from -30 to +30\n",
"- The angle of the clothes might vary from -15° to +15°\n",
"- Pictures can be blurry\n",
"\n",
"To test the robustness of our model, we will check whether it is sensitive to these perturbations. Following the examples [here](https://opencv-tutorial.readthedocs.io/en/latest/trans/transform.html), use the OpenCV library to implement these transformations on the images. You can visualise the transformed images with matplotlib."
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
],
"source": [
"def luminosity(image, intensity):\n",
" raise NotImplementedError\n",
" \n",
"def rotation(image, angle):\n",
" raise NotImplementedError\n",
"\n",
"def blur(image, intensity):\n",
" raise NotImplementedError\n",
" \n",
"plt.imshow(blur(image_0, 15), cmap='gray')"
]
},
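{
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible sketch of the three transformations with OpenCV follows. It relies on a few assumptions that you should check against your data:\n",
"- Pixel values lie in `[0, max_value]` (255 by default), and the luminosity `intensity` is expressed on a 0-255 scale. If the images are normalised to `[0, 1]`, pass `max_value=1.0`.\n",
"- `angle` is in degrees, and the rotation is around the image centre. OpenCV rotates counter-clockwise for positive angles, and the empty corners are filled with black.\n",
"- The blur `intensity` is used as the size of a Gaussian kernel, rounded up to an odd integer as `cv2.GaussianBlur` requires.\n",
"\n",
"Each function accepts either a `(28, 28)` image or the `(1, 28, 28)` array produced by `load_image` below, and returns an array of the same shape."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import cv2\n",
"\n",
"\n",
"def _as_2d(image):\n",
"    # OpenCV would read a (1, 28, 28) array as a 1x28 image with 28 channels,\n",
"    # so work on a contiguous float32 (28, 28) array and remember the original shape\n",
"    img = np.ascontiguousarray(image, dtype=np.float32)\n",
"    return img.reshape(img.shape[-2:]), img.shape\n",
"\n",
"\n",
"def luminosity(image, intensity, max_value=255.0):\n",
"    # Assumption: pixels lie in [0, max_value] and intensity is given on a 0-255 scale\n",
"    img, shape = _as_2d(image)\n",
"    out = img + intensity * max_value / 255.0\n",
"    return np.clip(out, 0.0, max_value).reshape(shape)\n",
"\n",
"\n",
"def rotation(image, angle):\n",
"    # Rotate by angle degrees around the image centre (counter-clockwise if positive);\n",
"    # areas coming from outside the image are filled with 0 (black)\n",
"    img, shape = _as_2d(image)\n",
"    h, w = img.shape\n",
"    m = cv2.getRotationMatrix2D(((w - 1) / 2, (h - 1) / 2), angle, 1.0)\n",
"    return cv2.warpAffine(img, m, (w, h)).reshape(shape)\n",
"\n",
"\n",
"def blur(image, intensity):\n",
"    # Assumption: intensity is the Gaussian kernel size, rounded up to an odd value\n",
"    img, shape = _as_2d(image)\n",
"    k = int(intensity)\n",
"    if k <= 1:\n",
"        return img.copy().reshape(shape)\n",
"    if k % 2 == 0:\n",
"        k += 1\n",
"    return cv2.GaussianBlur(img, (k, k), 0).reshape(shape)\n",
"\n",
"\n",
"# Example (uses plt and image_0 from the cells above):\n",
"# plt.imshow(blur(image_0, 15), cmap='gray')"
]
},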
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To apply these transformations and compute the stability of a network with respect to them, we will use AIMOS, a tool developed at CEA that does exactly that.\n",
"\n",
"After installing the `aimos` module (provided locally), the `call_aimos` function below runs AIMOS on your transformation over a specified range of values and plots the result."
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
],
"source": [
"!pip install ./bin/aimos_compile\n",
"from aimos import core"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {
"scrolled": false
},
"outputs": [
],
"source": [
"import numpy as np\n",
"import cv2\n",
"\n",
"%matplotlib inline\n",
"\n",
"\n",
"def load_image(path):\n",
" return np.load(str(path)).astype(np.float32).reshape((1, 28, 28))\n",
"\n",
"def call_aimos(transformation, value_range):\n",
" def identity(y):\n",
" return y\n",
" \n",
" return core.main(\n",
" \"fmnist/images/\",\n",
" \"fmnist/\",\n",
" (transformation, identity),\n",
" custom_load=load_image,\n",
" fn_range=value_range,\n",
" single_plot=True,\n",
" verbose=False,\n",
" )\n",
"\n",
"res = call_aimos(luminosity, range(0, 10))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point, you can run the above function on each of the transformations we want to test and compare how the different models respond.\n",
"\n",
"**What do you notice?**\n",
"\n",
"### Step 3: Choose a model\n",
"\n",
"**Given all of these results, which model would you choose, and why?**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}