From ba98a7fd4b78ce88b4a07b41fbb8f55fa030641d Mon Sep 17 00:00:00 2001
From: Arthur
Date: Thu, 11 Feb 2021 09:00:27 +0100
Subject: [PATCH] Add code and modifications

---
 Apprentissage_MB.ipynb               | 628 ++++++++++++++++++++
 Apprentissage_MSELoss_avec_GPU.ipynb | 505 ----------------
 Apprentissage_initial_dataset.ipynb  |  38 +-
 Benchmark.ipynb                      |  14 +-
 Benchmark_MB.ipynb                   | 723 +++++++++++++++++++++++
 display_bench.ipynb                  | 853 ++++++++++++++++++++++++++-
 view_weights.ipynb                   |  14 +-
 7 files changed, 2220 insertions(+), 555 deletions(-)
 create mode 100644 Apprentissage_MB.ipynb
 delete mode 100755 Apprentissage_MSELoss_avec_GPU.ipynb
 create mode 100644 Benchmark_MB.ipynb

diff --git a/Apprentissage_MB.ipynb b/Apprentissage_MB.ipynb
new file mode 100644
index 0000000..74fe029
--- /dev/null
+++ b/Apprentissage_MB.ipynb
@@ -0,0 +1,628 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# All of this code targets the following environment\n",
+    "# python 3.7\n",
+    "# opencv 3.1.0\n",
+    "# pytorch 1.4.0\n",
+    "\n",
+    "import torch\n",
+    "from torch.autograd import Variable\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import cv2\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import random\n",
+    "import math\n",
+    "import pickle\n",
+    "from PIL import Image\n",
+    "import sys\n",
+    "from IPython.display import clear_output\n",
+    "from datetime import datetime"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The functions in this block are not used by the network itself; they are utility helpers\n",
+    "\n",
+    "\n",
+    "def tensor_imshow(im_tensor,cannel):\n",
+    "    b,c,h,w=im_tensor.shape\n",
+    "    if c==1:\n",
+    "        plt.imshow(im_tensor.squeeze().detach().numpy())\n",
+    "    else:\n",
+    "        plt.imshow(im_tensor.squeeze().detach().numpy()[cannel,:])\n",
+    "\n",
+    "# Get one training sample\n",
+    "# frag,vt=get_training_fragment(frag_size,image)\n",
+    "# frag is a square patch of size (frag_size*frag_size) cut from the image (its location is random)\n",
+    "# vt is the ground truth, a Dirac map.\n",
+    "def get_training_fragment(frag_size,im):\n",
+    "    h,w,c=im.shape\n",
+    "    n=random.randint(0,int(h/frag_size)-1)\n",
+    "    m=random.randint(0,int(w/frag_size)-1) \n",
+    "    shape=frag_size/4\n",
+    "    vt_h=math.ceil((h+1)/shape)\n",
+    "    vt_w=math.ceil((w+1)/shape)\n",
+    "    vt=np.zeros([vt_h,vt_w])\n",
+    "    vt_h_po=round((vt_h-1)*(n*frag_size/(h-1)+(n+1)*frag_size/(h-1))/2)\n",
+    "    vt_w_po=round((vt_w-1)*(m*frag_size/(w-1)+(m+1)*frag_size/(w-1))/2)\n",
+    "    vt[vt_h_po,vt_w_po]=1\n",
+    "    vt = np.float32(vt)\n",
+    "    vt=torch.from_numpy(vt.reshape(1,1,vt_h,vt_w))\n",
+    "    \n",
+    "    return im[n*frag_size:(n+1)*frag_size,m*frag_size:(m+1)*frag_size,:],vt\n",
+    "\n",
+    "# This function converts an image into a Tensor variable.\n",
+    "# All of the network's data is of type Tensor\n",
+    "# Img.shape=[Height,Width,Channel]\n",
+    "# Tensor.shape=[Batch,Channel,Height,Width]\n",
+    "def img2tensor(im):\n",
+    "    im=np.array(im,dtype=\"float32\")\n",
+    "    tensor_cv = torch.from_numpy(np.transpose(im, (2, 0, 1)))\n",
+    "    im_tensor=tensor_cv.unsqueeze(0)\n",
+    "    return im_tensor\n",
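A quick sanity check of the Dirac ground truth built by get_training_fragment above, a sketch with made-up sizes that relies only on the helpers just defined: for a 256x256 image and frag_size=16, the map has ceil(257/4) = 65 cells per side and a single peak.

    im = np.random.rand(256, 256, 3).astype("float32")  # stand-in for a fresque
    frag, vt = get_training_fragment(16, im)
    print(frag.shape)       # (16, 16, 3), the randomly placed square patch
    print(vt.shape)         # torch.Size([1, 1, 65, 65])
    print(vt.sum().item())  # 1.0 -> exactly one Dirac peak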
+    "\n",
+    "# Find the coordinates of the maximum value in a correlation map\n",
+    "# x,y=show_coordonnee(correlation map)\n",
+    "def show_coordonnee(position_pred):\n",
+    "    map_corre=position_pred.squeeze().detach().numpy()\n",
+    "    h,w=map_corre.shape\n",
+    "    max_value=map_corre.max()\n",
+    "    coordonnee=np.where(map_corre==max_value)\n",
+    "    return coordonnee[0].mean()/h,coordonnee[1].mean()/w\n",
+    "\n",
+    "# Filter patches according to the number of black pixels they contain\n",
+    "# Returns True if the proportion of non-black pixels exceeds the threshold (seuillage), False otherwise\n",
+    "def test_fragment32_32(frag,seuillage):\n",
+    "    a=frag[:,:,0]+frag[:,:,1]+frag[:,:,2]\n",
+    "    mask = (a == 0)\n",
+    "    arr_new = a[mask]\n",
+    "    if arr_new.size/a.size<=(1-seuillage):\n",
+    "        return True\n",
+    "    else:\n",
+    "        return False\n",
+    "    \n",
+    "# These two functions save the network to a file\n",
+    "# and load a stored network back from a file\n",
+    "def save_net(file_path,net):\n",
+    "    pkl_file = open(file_path, 'wb')\n",
+    "    pickle.dump(net,pkl_file)\n",
+    "    pkl_file.close()\n",
+    "def load_net(file_path): \n",
+    "    pkl_file = open(file_path, 'rb')\n",
+    "    net= pickle.load(pkl_file)\n",
+    "    pkl_file.close()\n",
+    "    return net"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create DeepMatch-style weights as the initial value of Conv1 (optional)\n",
+    "def ini():\n",
+    "    kernel=torch.zeros([8,3,3,3])\n",
+    "    array_0=np.array([[1,2,1],[0,0,0],[-1,-2,-1]],dtype='float32')\n",
+    "    array_1=np.array([[2,1,0],[1,0,-1],[0,-1,-2]],dtype='float32')\n",
+    "    array_2=np.array([[1,0,-1],[2,0,-2],[1,0,-1]],dtype='float32')\n",
+    "    array_3=np.array([[0,-1,-2],[1,0,-1],[2,1,0]],dtype='float32')\n",
+    "    array_4=np.array([[-1,-2,-1],[0,0,0],[1,2,1]],dtype='float32')\n",
+    "    array_5=np.array([[-2,-1,0],[-1,0,1],[0,1,2]],dtype='float32')\n",
+    "    array_6=np.array([[-1,0,1],[-2,0,2],[-1,0,1]],dtype='float32')\n",
+    "    array_7=np.array([[0,1,2],[-1,0,1],[-2,-1,0]],dtype='float32')\n",
+    "    for i in range(3):\n",
+    "        kernel[0,i,:]=torch.from_numpy(array_0)\n",
+    "        kernel[1,i,:]=torch.from_numpy(array_1)\n",
+    "        kernel[2,i,:]=torch.from_numpy(array_2)\n",
+    "        kernel[3,i,:]=torch.from_numpy(array_3)\n",
+    "        kernel[4,i,:]=torch.from_numpy(array_4)\n",
+    "        kernel[5,i,:]=torch.from_numpy(array_5)\n",
+    "        kernel[6,i,:]=torch.from_numpy(array_6)\n",
+    "        kernel[7,i,:]=torch.from_numpy(array_7)\n",
+    "    return torch.nn.Parameter(kernel,requires_grad=True) \n",
+    "\n",
+    "# Compute the initial weights of the add convolution layer\n",
+    "# n, m means there are n * m sub-patches in the input patch\n",
+    "# For example, if the input patch is 16 * 16: for the 4 * 4 patches of the first layer, n = 4, m = 4\n",
+    "# for the 8 * 8 patches of the second layer, n = 2, m = 2\n",
+    "def kernel_add_ini(n,m):\n",
+    "    input_canal=int(n*m)\n",
+    "    output_canal=int(n/2)*int(m/2)\n",
+    "    for i in range(int(n/2)):\n",
+    "        for j in range(int(m/2)):\n",
+    "            kernel_add=np.zeros([1,input_canal],dtype='float32')\n",
+    "            kernel_add[0,i*2*m+j*2]=1\n",
+    "            kernel_add[0,i*2*m+j*2+1]=1\n",
+    "            kernel_add[0,(i*2+1)*m+j*2]=1\n",
+    "            kernel_add[0,(i*2+1)*m+j*2+1]=1\n",
+    "            if i==0 and j==0:\n",
+    "                add=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
+    "            else:\n",
+    "                add_=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
+    "                add=torch.cat((add,add_),0)\n",
+    "    return torch.nn.Parameter(add,requires_grad=False) \n",
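To see what kernel_add_ini builds, a minimal sketch for the smallest grid (n=m=2): the four sub-patch maps are summed into one by a single 1x1 convolution whose kernel is all ones.

    add = kernel_add_ini(2, 2)
    print(add.shape)        # torch.Size([1, 4, 1, 1]) -> 4 input maps, 1 output map
    print(add.reshape(-1))  # tensor([1., 1., 1., 1.])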
+    "\n",
+    "# Compute the initial weights of the shift convolution layer\n",
+    "# shift+add together implement the aggregation step\n",
+    "# See above for the parameters n and m. \n",
+    "# For more detailed steps, please see my internship report\n",
+    "def kernel_shift_ini(n,m):\n",
+    "    input_canal=int(n*m)\n",
+    "    output_canal=int(n*m)\n",
+    "    \n",
+    "    kernel_shift=torch.zeros([output_canal,input_canal,3,3])\n",
+    "    \n",
+    "    array_0=np.array([[1,0,0],[0,0,0],[0,0,0]],dtype='float32')\n",
+    "    array_1=np.array([[0,0,1],[0,0,0],[0,0,0]],dtype='float32')\n",
+    "    array_2=np.array([[0,0,0],[0,0,0],[1,0,0]],dtype='float32')\n",
+    "    array_3=np.array([[0,0,0],[0,0,0],[0,0,1]],dtype='float32')\n",
+    "    \n",
+    "    kernel_shift_0=torch.from_numpy(array_0)\n",
+    "    kernel_shift_1=torch.from_numpy(array_1)\n",
+    "    kernel_shift_2=torch.from_numpy(array_2)\n",
+    "    kernel_shift_3=torch.from_numpy(array_3)\n",
+    "    \n",
+    "    \n",
+    "    for i in range(n):\n",
+    "        for j in range(m):\n",
+    "            if i==0 and j==0:\n",
+    "                kernel_shift[0,0,:]=kernel_shift_0\n",
+    "            else:\n",
+    "                if i%2==0 and j%2==0:\n",
+    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_0\n",
+    "                if i%2==0 and j%2==1:\n",
+    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_1\n",
+    "                if i%2==1 and j%2==0:\n",
+    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_2\n",
+    "                if i%2==1 and j%2==1:\n",
+    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_3\n",
+    "    \n",
+    "    return torch.nn.Parameter(kernel_shift,requires_grad=False) \n",
+    "\n",
+    "# Find the small patch (4 * 4) in the n-th row and m-th column of the input patch\n",
+    "# This is used to compute the convolution and obtain the correlation map\n",
+    "def get_patch(fragment,psize,n,m):\n",
+    "    return fragment[:,:,n*psize:(n+1)*psize,m*psize:(m+1)*psize]\n",
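The shift and add kernels are meant to be used back to back. A sketch on dummy data (n=m=2, with bias disabled so only the fixed kernels act): four neighbouring sub-patch correlation maps collapse into one aggregated map of the same spatial size.

    shift = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1, bias=False)
    shift.weight = kernel_shift_ini(2, 2)   # re-aligns each sub-patch map
    add = nn.Conv2d(4, 1, kernel_size=1, stride=1, padding=0, bias=False)
    add.weight = kernel_add_ini(2, 2)       # sums the 2x2 group of maps
    maps = torch.rand(1, 4, 8, 8)           # 2*2 sub-patch correlation maps
    print(add(shift(maps)).shape)           # torch.Size([1, 1, 8, 8])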
+    "\n",
+    "###################################################################################################################\n",
+    "class Net(nn.Module):\n",
+    "    def __init__(self,frag_size,psize):\n",
+    "        super(Net, self).__init__()\n",
+    "        \n",
+    "        h_fr=frag_size\n",
+    "        w_fr=frag_size\n",
+    "        \n",
+    "        n=int(h_fr/psize) # n*m patches in the input patch\n",
+    "        m=int(w_fr/psize)\n",
+    "        \n",
+    "        self.conv1 = nn.Conv2d(3,8,kernel_size=3,stride=1,padding=1)\n",
+    "        # To initialise Conv1 with the DeepMatch weights, run the following line\n",
+    "        # self.conv1.weight=ini()\n",
+    "        self.Relu = nn.ReLU(inplace=True)\n",
+    "        self.maxpooling=nn.MaxPool2d(3,stride=2, padding=1)\n",
+    "        \n",
+    "        self.shift1=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
+    "        self.shift1.weight=kernel_shift_ini(n,m)\n",
+    "        self.add1 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
+    "        self.add1.weight=kernel_add_ini(n,m)\n",
+    "        \n",
+    "        n=int(n/2)\n",
+    "        m=int(m/2)\n",
+    "        if n>=2 and m>=2: # if n=m=1, the network needs no further layers to aggregate the correlation maps\n",
+    "            self.shift2=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
+    "            self.shift2.weight=kernel_shift_ini(n,m)\n",
+    "            self.add2 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
+    "            self.add2.weight=kernel_add_ini(n,m)\n",
+    "        \n",
+    "        n=int(n/2)\n",
+    "        m=int(m/2)\n",
+    "        if n>=2 and m>=2:\n",
+    "            self.shift3=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
+    "            self.shift3.weight=kernel_shift_ini(n,m)\n",
+    "            self.add3 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
+    "            self.add3.weight=kernel_add_ini(n,m)\n",
+    "    \n",
+    "    def get_descripteur(self,img,using_cuda):\n",
+    "        # Use Conv1 to compute the descriptor\n",
+    "        descripteur_img=self.Relu(self.conv1(img))\n",
+    "        b,c,h,w=descripteur_img.shape\n",
+    "        couche_constante = 0.5 * torch.ones([b, 1, h, w])\n",
+    "        if using_cuda:\n",
+    "            couche_constante=couche_constante.cuda()\n",
+    "        # Append a constant layer to avoid division by 0 during normalisation\n",
+    "        descripteur_img = torch.cat((descripteur_img,couche_constante),1)\n",
+    "        # normalisation\n",
+    "        descripteur_img_norm = F.normalize(descripteur_img)\n",
+    "        return descripteur_img_norm\n",
+    "    \n",
+    "    def forward(self,img,frag,using_cuda):\n",
+    "        psize=4\n",
+    "        # Use Conv1 to compute the descriptors\n",
+    "        descripteur_input2=self.get_descripteur(frag,using_cuda)\n",
+    "        descripteur_input1=self.get_descripteur(img,using_cuda)\n",
+    "        \n",
+    "        b,c,h,w=frag.shape\n",
+    "        n=int(h/psize)\n",
+    "        m=int(w/psize)\n",
+    "        \n",
+    "        db,dc,dh,dw = descripteur_input1.shape\n",
+    "        \n",
+    "        #######################################\n",
+    "        # Compute the correlation map by convolution for each of the n*m smaller patches.\n",
+    "        for i in range(n):\n",
+    "            for j in range(m):\n",
+    "                if i==0 and j==0:\n",
+    "                    # This line had to change because of the convolution dimensions for batches\n",
+    "                    map_corre=F.conv2d(descripteur_input1.view(1,db*dc,dh,dw),get_patch(descripteur_input2,psize,i,j),padding=2,groups=db)\n",
+    "                    map_corre=map_corre.view(db,1,map_corre.size(2),map_corre.size(3))\n",
+    "                else:\n",
+    "                    a=F.conv2d(descripteur_input1.view(1,db*dc,dh,dw),get_patch(descripteur_input2,psize,i,j),padding=2, groups=db)\n",
+    "                    a=a.view(db,1,a.size(2),a.size(3))\n",
+    "                    map_corre=torch.cat((map_corre,a),1)\n",
+    "        \n",
+    "        ########################################\n",
+    "        # Aggregation step\n",
+    "        map_corre=self.maxpooling(map_corre)\n",
+    "        map_corre=self.shift1(map_corre)\n",
+    "        map_corre=self.add1(map_corre)\n",
+    "        \n",
+    "        #########################################\n",
+    "        # Repeat the aggregation step until we obtain the correlation map of the whole input patch\n",
+    "        n=int(n/2)\n",
+    "        m=int(m/2)\n",
+    "        if n>=2 and m>=2:\n",
+    "            map_corre=self.maxpooling(map_corre)\n",
+    "            map_corre=self.shift2(map_corre)\n",
+    "            map_corre=self.add2(map_corre)\n",
+    "        \n",
+    "        \n",
+    "        n=int(n/2)\n",
+    "        m=int(m/2)\n",
+    "        if n>=2 and m>=2:\n",
+    "            map_corre=self.maxpooling(map_corre)\n",
+    "            map_corre=self.shift3(map_corre)\n",
+    "            map_corre=self.add3(map_corre)\n",
+    "        \n",
+    "        \n",
+    "        #b,c,h,w=map_corre.shape\n",
+    "        # Normalisation by dividing by the maximum\n",
+    "        map_corre=map_corre/map_corre.max()\n",
+    "        # SoftMax normalisation\n",
+    "        #map_corre=(F.softmax(map_corre.reshape(1,1,h*w,1),dim=2)).reshape(b,c,h,w)\n",
+    "        return map_corre"
+   ]
+  },
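The grouped convolution in forward() is the batching change this notebook introduces: folding the batch into the channel axis lets each sample be cross-correlated with its own 4x4 descriptor patch in a single F.conv2d call. A standalone sketch with illustrative shapes (8 Conv1 channels plus the constant layer gives c=9):

    b, c, h, w = 2, 9, 32, 32
    fresques = torch.rand(b, c, h, w)    # batch of fresque descriptors
    patches = torch.rand(b, c, 4, 4)     # one 4x4 descriptor patch per sample
    out = F.conv2d(fresques.view(1, b * c, h, w), patches, padding=2, groups=b)
    out = out.view(b, 1, out.size(2), out.size(3))
    print(out.shape)                     # torch.Size([2, 1, 33, 33])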
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Dataset and Dataloader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import re\n",
+    "from PIL import Image\n",
+    "from torchvision import transforms\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "from time import time\n",
+    "\n",
+    "\n",
+    "class FragmentDataset(Dataset):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        fragments_path, \n",
+    "        train, \n",
+    "        frags_transform=transforms.ToTensor(),\n",
+    "        fresques_transform=None,\n",
+    "        vts_transform=None,\n",
+    "    ):\n",
+    "        \"\"\"\n",
+    "        Parameters\n",
+    "        ----------\n",
+    "        fragments_path: str\n",
+    "            Path to the root folder with the fragments folders.\n",
+    "        train: boolean\n",
+    "            True for the train set (__dev__), False for the test set (__bench__).\n",
+    "        frags_transform: torchvision.transform\n",
+    "            Transform to apply to all fragment images. Default: ToTensor().\n",
+    "        fresques_transform: torchvision.transform\n",
+    "            Transform to apply to all fresque images. Defaults to frags_transform if None.\n",
+    "        vts_transform: transform to apply to all vts images. Default: ToTensor().\n",
+    "        \"\"\"\n",
+    "        self.base_path = fragments_path\n",
+    "        self.frags_transform = frags_transform\n",
+    "        self.fresques_transform = fresques_transform if fresques_transform else frags_transform\n",
+    "        self.fragments_list = []\n",
+    "        self.vts_transform = vts_transform\n",
+    "        \n",
+    "        # Separate train (dev) fragments from test (bench) fragments\n",
+    "        self.match_expr = \"_dev_\" if train else \"_bench_\"\n",
+    "        \n",
+    "        fragments_path = os.path.join(self.base_path, \"fragments\")\n",
+    "        for fresque_dir in os.listdir(fragments_path):\n",
+    "            current_path = os.path.join(fragments_path, fresque_dir)\n",
+    "            \n",
+    "            if \"fresque\" in current_path: \n",
+    "                # Avoids looking at extra files in the dirs.\n",
+    "                \n",
+    "                # Get path to current fresque (ie: ..path/fresque0.ppm).\n",
+    "                fresque_name = current_path.split(\"/\")[-1] + \".ppm\"\n",
+    "                full_fresque_path = os.path.join(self.base_path, fresque_name) \n",
+    "                \n",
+    "                # Get path to every fragment for that fresque (ie: ..path/fresque0/frag_bench_000.ppm)\n",
+    "                all_fragments_fresque = sorted(os.listdir(current_path))\n",
+    "                \n",
+    "                # Get path to every vt for that fresque (ie: ..path/fresque0/vt/frag_bench_000.ppm)\n",
+    "                vts_path = os.path.join(current_path, \"vt\")\n",
+    "                all_vts_fresque = sorted(os.listdir(vts_path))\n",
+    "                \n",
+    "                # Keep fragments that belong in that set (Train | Test) \n",
+    "                # and group them with the full fresque path (tuple)\n",
+    "                all_fragments_fresque = [\n",
+    "                    (os.path.join(current_path, frag_path), full_fresque_path, os.path.join(vts_path, vt_path))\n",
+    "                    for frag_path, vt_path in zip(all_fragments_fresque, all_vts_fresque)\n",
+    "                    if re.search(self.match_expr, frag_path) and re.search(self.match_expr, vt_path)\n",
+    "                ]\n",
+    "                \n",
+    "                self.fragments_list.extend(all_fragments_fresque)\n",
+    "    \n",
+    "    def __len__(self):\n",
+    "        return len(self.fragments_list)\n",
+    "    \n",
+    "    def __getitem__(self, idx):\n",
+    "        # Loads the fragment and the full fresque as tensors.\n",
+    "        fragment = Image.open(self.fragments_list[idx][0])\n",
+    "        fresque = Image.open(self.fragments_list[idx][1])\n",
+    "        \n",
+    "        with open(self.fragments_list[idx][2],'r') as f:\n",
+    "            data_vt_raw = f.readlines()\n",
+    "            data_vt = [int(d.rstrip('\\r\\n')) for d in data_vt_raw]\n",
+    "        \n",
+    "        # Construct vt\n",
+    "        vt = np.zeros((int(data_vt[0]/4)+1,int(data_vt[1]/4)+1))\n",
+    "        vt[int(data_vt[2]/4),int(data_vt[3]/4)] = 1\n",
+    "        vt = np.float32(vt)\n",
+    "        vt = torch.from_numpy(vt.reshape(1,int(data_vt[0]/4)+1,int(data_vt[1]/4)+1))\n",
+    "        \n",
+    "        return self.frags_transform(fragment), self.fresques_transform(fresque), vt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Usage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fresques_tnsf = transforms.Compose([\n",
+    "    transforms.Resize((1000, 1000)),\n",
+    "    transforms.ToTensor()\n",
+    "])\n",
+    "\n",
+    "train = FragmentDataset(fragments_path=\"training_data_random_shift_color\", train=True, fresques_transform=fresques_tnsf)\n",
+    "test = FragmentDataset(fragments_path=\"training_data_random_shift_color\", train=False, fresques_transform=fresques_tnsf)\n",
+    "\n",
+    "bs = 4\n",
+    "\n",
+    "train_loader = DataLoader(train, batch_size=bs, num_workers=4, pin_memory=False, shuffle=True)\n",
+    "test_loader = DataLoader(test, batch_size=bs, num_workers=4, pin_memory=False, shuffle=True)"
+   ]
+  },
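A minimal smoke test of the loaders, assuming the training_data_random_shift_color folders exist; the fragment size depends on the fragment files, while the fresque size is fixed by the Resize above.

    fragments, fresques, vts = next(iter(train_loader))
    print(fragments.shape)  # e.g. torch.Size([4, 3, 16, 16]) for 16x16 fragment files
    print(fresques.shape)   # torch.Size([4, 3, 1000, 1000]) after Resize((1000, 1000))
    print(vts.shape)        # torch.Size([4, 1, vt_h, vt_w]), one Dirac map per sample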
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[EPOCH 2] Batch 4499/4500\n",
+      "Time per batch: 2.74\n",
+      "Done with epoch 2\n",
+      "Net saved to ./trained_net/net_trainned_MB4_02-10_20-49_0003\n",
+      "Weights saved to ./trained_net/save_weights_MB4_02-10_20-49_0003\n"
+     ]
+    }
+   ],
+   "source": [
+    "frag_size = 16\n",
+    "psize = 4\n",
+    "\n",
+    "net = Net(frag_size, psize).cuda()\n",
+    "\n",
+    "optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
+    "loss_func = torch.nn.MSELoss()\n",
+    "\n",
+    "num_epochs = 3\n",
+    "\n",
+    "loss_value = []\n",
+    "para_value = []\n",
+    "w_values = []\n",
+    "\n",
+    "time_old = time()\n",
+    "\n",
+    "for epoch in range(num_epochs):\n",
+    "    i=0\n",
+    "    for fragments, fresques, vts in train_loader:\n",
+    "        \n",
+    "        clear_output(wait=True)\n",
+    "        print(\"[EPOCH {}] Batch {}/{}\\nTime per batch: {:.3}\".format(epoch,i,len(train_loader),time()-time_old))\n",
+    "        time_old = time()\n",
+    "\n",
+    "        fragments = fragments.cuda()\n",
+    "        fresques = fresques.cuda()\n",
+    "        \n",
+    "        preds = net(fresques, fragments, True) \n",
+    "        optimizer.zero_grad()\n",
+    "        \n",
+    "        del(fragments)\n",
+    "        del(fresques)\n",
+    "        vts = vts.cuda()\n",
+    "        cost = loss_func(vts, preds)\n",
+    "        cost.backward()\n",
+    "        del(vts)\n",
+    "        optimizer.step()\n",
+    "        \n",
+    "        if i%10==0:\n",
+    "            w_values.append(net.conv1.weight.data.cpu().numpy())\n",
+    "        i+=1\n",
+    "\n",
+    "        loss_value.append(cost.item())\n",
+    "        torch.cuda.empty_cache()\n",
+    "    print('Done with epoch ', epoch)\n",
+    "    \n",
+    "# Save the network\n",
+    "save_dir = './trained_net/'\n",
+    "expe_id = 3\n",
+    "net_filename = save_dir + \"net_trainned_MB{}_{}_{:04}\".format(bs,datetime.now().strftime(\"%m-%d_%H-%M\"),expe_id)\n",
+    "save_net(net_filename,net)\n",
+    "\n",
+    "# Save the weights\n",
+    "poids_filename = save_dir + \"save_weights_MB{}_{}_{:04}\".format(bs,datetime.now().strftime(\"%m-%d_%H-%M\"),expe_id)\n",
+    "with open(poids_filename,'wb') as f:\n",
+    "    pickle.dump(w_values,f)\n",
+    "\n",
+    "print(\"Net saved to {}\".format(net_filename))\n",
+    "print(\"Weights saved to {}\".format(poids_filename))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "13500"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(loss_value)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[<matplotlib.lines.Line2D at 0x...>]"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "<base64 PNG data elided: training loss curve over the 13500 recorded batches>",
+      "text/plain": [
+       "<Figure size ... with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "plt.plot(loss_value)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'net' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-5-...>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mfile_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"./net_trainned6000_MB_102\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msave_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m: name 'net' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "file_path=\"./net_trainned6000_MB_102\"\n",
+    "save_net(file_path,net)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/Apprentissage_MSELoss_avec_GPU.ipynb b/Apprentissage_MSELoss_avec_GPU.ipynb
deleted file mode 100755
index 92e3073..0000000
--- a/Apprentissage_MSELoss_avec_GPU.ipynb
+++ /dev/null
@@ -1,505 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# All of this code targets the following environment\n",
-    "# python 3.7\n",
-    "# opencv 3.1.0\n",
-    "# pytorch 1.4.0\n",
-    "\n",
-    "import torch\n",
-    "from torch.autograd import Variable\n",
-    "import torch.nn as nn\n",
-    "import torch.nn.functional as F\n",
-    "import cv2\n",
-    "import matplotlib.pyplot as plt\n",
-    "import numpy as np\n",
-    "import random\n",
-    "import math\n",
-    "import pickle\n",
-    "from PIL import Image\n",
-    "import sys"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# The functions in this block are not used by the network itself; they are utility helpers\n",
-    "\n",
-    "\n",
-    "def tensor_imshow(im_tensor,cannel):\n",
-    "    b,c,h,w=im_tensor.shape\n",
-    "    if c==1:\n",
-    "        plt.imshow(im_tensor.squeeze().detach().numpy())\n",
-    "    else:\n",
-    "        plt.imshow(im_tensor.squeeze().detach().numpy()[cannel,:])\n",
-    "\n",
-    "# Get one training sample\n",
-    "# frag,vt=get_training_fragment(frag_size,image)\n",
-    "# frag is a square patch of size (frag_size*frag_size) cut from the image (its location is random)\n",
-    "# vt is the ground truth, a Dirac map.\n",
-    "def get_training_fragment(frag_size,im):\n",
-    "    h,w,c=im.shape\n",
-    "    n=random.randint(0,int(h/frag_size)-1)\n",
-    "    m=random.randint(0,int(w/frag_size)-1) \n",
-    "    shape=frag_size/4\n",
-    "    vt_h=math.ceil((h+1)/shape)\n",
-    "    vt_w=math.ceil((w+1)/shape)\n",
-    "    vt=np.zeros([vt_h,vt_w])\n",
-    "    vt_h_po=round((vt_h-1)*(n*frag_size/(h-1)+(n+1)*frag_size/(h-1))/2)\n",
-    "    vt_w_po=round((vt_w-1)*(m*frag_size/(w-1)+(m+1)*frag_size/(w-1))/2)\n",
-    "    vt[vt_h_po,vt_w_po]=1\n",
-    "    vt = np.float32(vt)\n",
-    "    vt=torch.from_numpy(vt.reshape(1,1,vt_h,vt_w))\n",
-    "    \n",
-    "    return im[n*frag_size:(n+1)*frag_size,m*frag_size:(m+1)*frag_size,:],vt\n",
-    "\n",
-    "# This function converts an image into a Tensor variable.\n",
-    "# All of the network's data is of type Tensor\n",
-    "# Img.shape=[Height,Width,Channel]\n",
-    "# Tensor.shape=[Batch,Channel,Height,Width]\n",
-    "def img2tensor(im):\n",
-    "    im=np.array(im,dtype=\"float32\")\n",
-    "    tensor_cv = torch.from_numpy(np.transpose(im, (2, 0, 1)))\n",
-    "    im_tensor=tensor_cv.unsqueeze(0)\n",
-    "    return im_tensor\n",
-    "\n",
-    "# Find the coordinates of the maximum value in a correlation map\n",
-    "# x,y=show_coordonnee(correlation map)\n",
-    "def show_coordonnee(position_pred):\n",
-    "    map_corre=position_pred.squeeze().detach().numpy()\n",
-    "    h,w=map_corre.shape\n",
-    "    max_value=map_corre.max()\n",
-    "    coordonnee=np.where(map_corre==max_value)\n",
-    "    return coordonnee[0].mean()/h,coordonnee[1].mean()/w\n",
-    "\n",
-    "# Filter patches according to the number of black pixels they contain\n",
-    "# Returns True if the proportion of non-black pixels exceeds the threshold (seuillage), False otherwise\n",
-    "def test_fragment32_32(frag,seuillage):\n",
-    "    a=frag[:,:,0]+frag[:,:,1]+frag[:,:,2]\n",
-    "    mask = (a == 0)\n",
-    "    arr_new = a[mask]\n",
-    "    if arr_new.size/a.size<=(1-seuillage):\n",
-    "        return True\n",
-    "    else:\n",
-    "        return False\n",
-    "    \n",
-    "# These two functions save the network to a file\n",
-    "# and load a stored network back from a file\n",
-    "def save_net(file_path,net):\n",
-    "    pkl_file = open(file_path, 'wb')\n",
-    "    pickle.dump(net,pkl_file)\n",
-    "    pkl_file.close()\n",
-    "def load_net(file_path): \n",
-    "    pkl_file = open(file_path, 'rb')\n",
-    "    net= pickle.load(pkl_file)\n",
-    "    pkl_file.close()\n",
-    "    return net"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# The functions in this block build the network\n",
-    "\n",
-    "# Create DeepMatch-style weights as the initial value of Conv1 (optional)\n",
-    "def ini():\n",
-    "    kernel=torch.zeros([8,3,3,3])\n",
-    "    array_0=np.array([[1,2,1],[0,0,0],[-1,-2,-1]],dtype='float32')\n",
-    "    array_1=np.array([[2,1,0],[1,0,-1],[0,-1,-2]],dtype='float32')\n",
-    "    array_2=np.array([[1,0,-1],[2,0,-2],[1,0,-1]],dtype='float32')\n",
-    "    array_3=np.array([[0,-1,-2],[1,0,-1],[2,1,0]],dtype='float32')\n",
-    "    array_4=np.array([[-1,-2,-1],[0,0,0],[1,2,1]],dtype='float32')\n",
-    "    array_5=np.array([[-2,-1,0],[-1,0,1],[0,1,2]],dtype='float32')\n",
-    "    array_6=np.array([[-1,0,1],[-2,0,2],[-1,0,1]],dtype='float32')\n",
-    "    array_7=np.array([[0,1,2],[-1,0,1],[-2,-1,0]],dtype='float32')\n",
-    "    for i in range(3):\n",
-    "        kernel[0,i,:]=torch.from_numpy(array_0)\n",
-    "        kernel[1,i,:]=torch.from_numpy(array_1)\n",
-    "        kernel[2,i,:]=torch.from_numpy(array_2)\n",
-    "        kernel[3,i,:]=torch.from_numpy(array_3)\n",
-    "        kernel[4,i,:]=torch.from_numpy(array_4)\n",
-    "        kernel[5,i,:]=torch.from_numpy(array_5)\n",
-    "        kernel[6,i,:]=torch.from_numpy(array_6)\n",
-    "        kernel[7,i,:]=torch.from_numpy(array_7)\n",
-    "    return torch.nn.Parameter(kernel,requires_grad=True) \n",
-    "\n",
-    "# Compute the initial weights of the add convolution layer\n",
-    "# n, m means there are n * m sub-patches in the input patch\n",
-    "# For example, if the input patch is 16 * 16: for the 4 * 4 patches of the first layer, n = 4, m = 4\n",
-    "# for the 8 * 8 patches of the second layer, n = 2, m = 2\n",
-    "def kernel_add_ini(n,m):\n",
-    "    input_canal=int(n*m)\n",
-    "    output_canal=int(n/2)*int(m/2)\n",
-    "    for i in range(int(n/2)):\n",
-    "        for j in range(int(m/2)):\n",
-    "            kernel_add=np.zeros([1,input_canal],dtype='float32')\n",
-    "            kernel_add[0,i*2*m+j*2]=1\n",
-    "            kernel_add[0,i*2*m+j*2+1]=1\n",
-    "            kernel_add[0,(i*2+1)*m+j*2]=1\n",
-    "            kernel_add[0,(i*2+1)*m+j*2+1]=1\n",
-    "            if i==0 and j==0:\n",
-    "                add=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
-    "            else:\n",
-    "                add_=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
-    "                add=torch.cat((add,add_),0)\n",
-    "    return torch.nn.Parameter(add,requires_grad=False) \n",
-    "\n",
-    "# Compute the initial weights of the shift convolution layer\n",
-    "# shift+add together implement the aggregation step\n",
-    "# See above for the parameters n and m. \n",
-    "# For more detailed steps, please see my internship report\n",
-    "def kernel_shift_ini(n,m):\n",
-    "    input_canal=int(n*m)\n",
-    "    output_canal=int(n*m)\n",
-    "    \n",
-    "    kernel_shift=torch.zeros([output_canal,input_canal,3,3])\n",
-    "    \n",
-    "    array_0=np.array([[1,0,0],[0,0,0],[0,0,0]],dtype='float32')\n",
-    "    array_1=np.array([[0,0,1],[0,0,0],[0,0,0]],dtype='float32')\n",
-    "    array_2=np.array([[0,0,0],[0,0,0],[1,0,0]],dtype='float32')\n",
-    "    array_3=np.array([[0,0,0],[0,0,0],[0,0,1]],dtype='float32')\n",
-    "    \n",
-    "    kernel_shift_0=torch.from_numpy(array_0)\n",
-    "    kernel_shift_1=torch.from_numpy(array_1)\n",
-    "    kernel_shift_2=torch.from_numpy(array_2)\n",
-    "    kernel_shift_3=torch.from_numpy(array_3)\n",
-    "    \n",
-    "    \n",
-    "    for i in range(n):\n",
-    "        for j in range(m):\n",
-    "            if i==0 and j==0:\n",
-    "                kernel_shift[0,0,:]=kernel_shift_0\n",
-    "            else:\n",
-    "                if i%2==0 and j%2==0:\n",
-    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_0\n",
-    "                if i%2==0 and j%2==1:\n",
-    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_1\n",
-    "                if i%2==1 and j%2==0:\n",
-    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_2\n",
-    "                if i%2==1 and j%2==1:\n",
-    "                    kernel_shift[i*m+j,i*m+j,:]=kernel_shift_3\n",
-    "    \n",
-    "    return torch.nn.Parameter(kernel_shift,requires_grad=False) \n",
-    "\n",
-    "# Find the small patch (4 * 4) in the n-th row and m-th column of the input patch\n",
-    "# This is used to compute the convolution and obtain the correlation map\n",
-    "def get_patch(fragment,psize,n,m):\n",
-    "    return fragment[:,:,n*psize:(n+1)*psize,m*psize:(m+1)*psize]\n",
-    "###################################################################################################################\n",
-    "class Net(nn.Module):\n",
-    "    def __init__(self,frag_size,psize):\n",
-    "        super(Net, self).__init__()\n",
-    "        \n",
-    "        h_fr=frag_size\n",
-    "        w_fr=frag_size\n",
-    "        \n",
-    "        n=int(h_fr/psize) # n*m patches in the input patch\n",
-    "        m=int(w_fr/psize)\n",
-    "        \n",
-    "        self.conv1 = nn.Conv2d(3,8,kernel_size=3,stride=1,padding=1)\n",
-    "        # To initialise Conv1 with the DeepMatch weights, run the following line\n",
-    "        # self.conv1.weight=ini()\n",
-    "        self.Relu = nn.ReLU(inplace=True)\n",
-    "        self.maxpooling=nn.MaxPool2d(3,stride=2, padding=1)\n",
-    "        \n",
-    "        self.shift1=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
-    "        self.shift1.weight=kernel_shift_ini(n,m)\n",
-    "        self.add1 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
-    "        self.add1.weight=kernel_add_ini(n,m)\n",
-    "        \n",
-    "        n=int(n/2)\n",
-    "        m=int(m/2)\n",
-    "        if n>=2 and m>=2: # if n=m=1, the network needs no further layers to aggregate the correlation maps\n",
-    "            self.shift2=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
-    "            self.shift2.weight=kernel_shift_ini(n,m)\n",
-    "            self.add2 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
-    "            self.add2.weight=kernel_add_ini(n,m)\n",
-    "        \n",
-    "        n=int(n/2)\n",
-    "        m=int(m/2)\n",
-    "        if n>=2 and m>=2:\n",
-    "            self.shift3=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
-    "            self.shift3.weight=kernel_shift_ini(n,m)\n",
-    "            self.add3 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
-    "            self.add3.weight=kernel_add_ini(n,m)\n",
-    "    \n",
-    "    def get_descripteur(self,img,using_cuda):\n",
-    "        # Use Conv1 to compute the descriptor\n",
-    "        descripteur_img=self.Relu(self.conv1(img))\n",
-    "        b,c,h,w=descripteur_img.shape\n",
-    "        couche_constante=0.5*torch.ones([1,1,h,w])\n",
-    "        if using_cuda:\n",
-    "            couche_constante=couche_constante.cuda()\n",
-    "        # Append a constant layer to avoid division by 0 during normalisation\n",
-    "        descripteur_img=torch.cat((descripteur_img,couche_constante),1)\n",
-    "        # normalisation\n",
-    "        descripteur_img_norm=descripteur_img/torch.norm(descripteur_img,dim=1)\n",
-    "        return descripteur_img_norm\n",
-    "    \n",
-    "    def forward(self,img,frag,using_cuda):\n",
-    "        psize=4\n",
-    "        # Use Conv1 to compute the descriptors\n",
-    "        descripteur_input1=self.get_descripteur(img,using_cuda)\n",
-    "        descripteur_input2=self.get_descripteur(frag,using_cuda)\n",
-    "        \n",
-    "        b,c,h,w=frag.shape\n",
-    "        n=int(h/psize)\n",
-    "        m=int(w/psize)\n",
-    "        \n",
-    "        #######################################\n",
-    "        # Compute the correlation map by convolution for each of the n*m smaller patches.\n",
-    "        for i in range(n):\n",
-    "            for j in range(m):\n",
-    "                if i==0 and j==0:\n",
-    "                    map_corre=F.conv2d(descripteur_input1,get_patch(descripteur_input2,psize,i,j),padding=2)\n",
-    "                else:\n",
-    "                    a=F.conv2d(descripteur_input1,get_patch(descripteur_input2,psize,i,j),padding=2)\n",
-    "                    map_corre=torch.cat((map_corre,a),1)\n",
-    "        ########################################\n",
-    "        # Aggregation step\n",
-    "        map_corre=self.maxpooling(map_corre)\n",
-    "        map_corre=self.shift1(map_corre)\n",
-    "        map_corre=self.add1(map_corre)\n",
-    "        \n",
-    "        #########################################\n",
-    "        # Repeat the aggregation step until we obtain the correlation map of the whole input patch\n",
-    "        n=int(n/2)\n",
-    "        m=int(m/2)\n",
-    "        if n>=2 and m>=2:\n",
-    "            map_corre=self.maxpooling(map_corre)\n",
-    "            map_corre=self.shift2(map_corre)\n",
-    "            map_corre=self.add2(map_corre)\n",
-    "        \n",
-    "        \n",
-    "        n=int(n/2)\n",
-    "        m=int(m/2)\n",
-    "        if n>=2 and m>=2:\n",
-    "            map_corre=self.maxpooling(map_corre)\n",
-    "            map_corre=self.shift3(map_corre)\n",
-    "            map_corre=self.add3(map_corre)\n",
-    "        \n",
-    "        \n",
-    "        b,c,h,w=map_corre.shape\n",
-    "        # Normalisation by dividing by the maximum\n",
-    "        map_corre=map_corre/(map_corre.max())\n",
-    "        # SoftMax normalisation\n",
-    "        #map_corre=(F.softmax(map_corre.reshape(1,1,h*w,1),dim=2)).reshape(b,c,h,w)\n",
-    "        return map_corre"
-   ]
-  },
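One difference between this deleted version and Apprentissage_MB.ipynb is the normalisation line above: dividing by torch.norm(x, dim=1) drops the channel dimension, so it only broadcasts correctly for batch size 1, which is presumably why the rewritten notebook switched to F.normalize. A small sketch of the equivalence:

    x = torch.rand(1, 9, 8, 8)      # batch of 1, as in this notebook
    y1 = x / torch.norm(x, dim=1)   # (1, 8, 8) divisor, broadcasts only because b == 1
    y2 = F.normalize(x)             # keepdim-safe: x / x.norm(dim=1, keepdim=True)
    print(torch.allclose(y1, y2))   # True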
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def run_net(net,img,frag,frag_size,using_cuda):\n",
-    "    h,w,c=frag.shape\n",
-    "    n=int(h/frag_size)\n",
-    "    m=int(w/frag_size)\n",
-    "    frag_list=[]\n",
-    "    #####################################\n",
-    "    # Cut square patches out of the fragment and put them into frag_list\n",
-    "    for i in range(n):\n",
-    "        for j in range(m):\n",
-    "            frag_32=frag[i*frag_size:(i+1)*frag_size,j*frag_size:(j+1)*frag_size]\n",
-    "            if test_fragment32_32(frag_32,0.6):\n",
-    "                frag_list.append(frag_32)\n",
-    "    img_tensor=img2tensor(img)\n",
-    "    ######################################\n",
-    "    if using_cuda:\n",
-    "        img_tensor=img_tensor.cuda()\n",
-    "    \n",
-    "    coordonnee_list=[]\n",
-    "    #######################################\n",
-    "    # Use the network to compute the positions of all the patches in frag_list[]\n",
-    "    # Put the result of the computation into coordonnee_list[]\n",
-    "    for i in range(len(frag_list)):\n",
-    "        frag_tensor=img2tensor(frag_list[i])\n",
-    "        if using_cuda:\n",
-    "            frag_tensor=frag_tensor.cuda()\n",
-    "        res=net.forward(img_tensor,frag_tensor,using_cuda)\n",
-    "        if using_cuda:\n",
-    "            res=res.cpu()\n",
-    "        po_h,po_w=show_coordonnee(res)\n",
-    "        coordonnee_list.append([po_h,po_w])\n",
-    "    h_img,w_img,c=img.shape\n",
-    "    position=[]\n",
-    "    for i in range(len(coordonnee_list)):\n",
-    "        x=int(round(h_img*coordonnee_list[i][0]))\n",
-    "        y=int(round(w_img*coordonnee_list[i][1]))\n",
-    "        position.append([x,y])\n",
-    "    return position"
-   ]
-  },
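A hypothetical end-to-end use of run_net (the file names are placeholders, not files shipped with this patch): load a pickled network, then locate every sufficiently non-black piece of a fragment image inside the fresque.

    net = load_net("./net_trainned6000")     # placeholder path to a pickled Net
    img = cv2.imread("./fresque0.ppm")       # fresque to search in
    frag = cv2.imread("./frag_eroded0.ppm")  # placeholder fragment image
    # frag_size must match the size the Net was built with (16 below);
    # using_cuda must match the device the loaded net lives on.
    positions = run_net(net, img, frag, 16, using_cuda=True)
    print(positions[:3])                     # [[y, x], ...] pixel coordinates in img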
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if __name__=='__main__':\n",
-    "    \n",
-    "    # The input patch size is 16*16\n",
-    "    frag_size=16\n",
-    "    # The size of the smallest patch in the network is fixed at 4*4\n",
-    "    psize=4\n",
-    "    using_cuda=True\n",
-    "    \n",
-    "    \n",
-    "    net=Net(frag_size,psize)\n",
-    "    \n",
-    "    # For each fresque, the number of iterations is 1000\n",
-    "    itera=1000\n",
-    "    \n",
-    "    if using_cuda:\n",
-    "        net=net.cuda()\n",
-    "    \n",
-    "    # Choose the optimizer and the loss function\n",
-    "    optimizer = torch.optim.Adam(net.parameters())\n",
-    "    loss_func = torch.nn.MSELoss()\n",
-    "    \n",
-    "    # During training, the evolution of the error is stored in loss_value=[]\n",
-    "    # and the evolution of the Conv1 weights is stored in para_value[]\n",
-    "    loss_value=[]\n",
-    "    para_value=[]\n",
-    "    ####################################################training_net\n",
-    "    \n",
-    "    # The training data are 6 fresques\n",
-    "    for n in range(6):\n",
-    "        im_path=\"./fresque\"+str(n)+\".ppm\"\n",
-    "        img_training=cv2.imread(im_path)\n",
-    "        h,w,c=img_training.shape\n",
-    "        \n",
-    "        # If the fresque is too large, downsample and shrink it\n",
-    "        while h*w>(1240*900):\n",
-    "            img_training=cv2.resize(img_training,(int(h/2),int(w/2)),interpolation=cv2.INTER_CUBIC)\n",
-    "            h,w,c=img_training.shape\n",
-    "        im_tensor=img2tensor(img_training)\n",
-    "        \n",
-    "        if using_cuda:\n",
-    "            im_tensor=im_tensor.cuda()\n",
-    "        for i in range(itera):\n",
-    "            # Every 100 iterations, record the weight changes\n",
-    "            if i%100==0:\n",
-    "                para=net.conv1.weight\n",
-    "                para=para.detach().cpu()\n",
-    "                para_value.append(para)\n",
-    "            frag,vt=get_training_fragment(frag_size,img_training)\n",
-    "            frag_tensor=img2tensor(frag)\n",
-    "            if using_cuda:\n",
-    "                vt=vt.cuda()\n",
-    "                frag_tensor=frag_tensor.cuda()\n",
-    "            # Run the network on the training patches and fresques\n",
-    "            frag_pred=net.forward(im_tensor,frag_tensor,using_cuda)\n",
-    "            b,c,h,w=vt.shape\n",
-    "            # Use the loss function to compute the error\n",
-    "            err_=loss_func(vt,frag_pred)\n",
-    "            # Use the optimizer to adjust the Conv1 weights\n",
-    "            optimizer.zero_grad()\n",
-    "            err_.backward(retain_graph=True)\n",
-    "            optimizer.step()\n",
-    "            \n",
-    "            loss_value.append(err_.tolist())\n",
-    "            \n",
-    "            del frag_tensor,frag_pred,err_,vt\n",
-    "            torch.cuda.empty_cache()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "6000"
-      ]
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "len(loss_value)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[<matplotlib.lines.Line2D at 0x...>]"
-      ]
-     },
-     "execution_count": 11,
-     "metadata": {},
-     "output_type": "execute_result"
-    },
-    {
-     "data": {
-      "image/png": "<base64 PNG data elided: training loss curve over the 6000 recorded iterations>",
NyJDXo/W6cKufKTzMnQPHSl1CqlVF+lVG+l1Bzt2AItmEMpNUUpdYlSaqj2J+SCuRFP3niZtcdeXWOZhjf9hsuw6qkx1jKje4XvxsihzjGcX6yoRk6h+xknSilsTMut1xDOzuPuV7AShaqwWSnqC0/f3M/6tX5AZ8sm1PeOiMNAmw2pl0+9Cpdx39J6e3tdw0fiHDvoU/+d5Lb8+kNnsCk9D498tAtvfe/8/kUXK+uVh+WpZXuw0sW0TcdhmjWpp3Tf01988fMl8zPtJtG+NO/OwegV29Lu2Mu3DcLkUd0R166FU/l1M8aisLQCQ2c7j/OSvrfXHcbvGzh7aM+Jc7i+X0cAwMa0XLdJv7YdOYvHPq4L+IdO2d+kXZN6CtOWJAOA4eGar1NO4usU/XsDjjtl1V77D+Oc23y+rBKqBmjToonTOW+9ve4w7h0Rhy5tuC9AY8YeOoB7RsQhId5+KCW6SSSGd6+b6fLbsb3szrdtUZcUrEOrusfcMcl/Hv5XXRrelGz9+fC1Dp+xX7npOOBSG3CDYfBL32HIbN9v6VdWaa688+R7DOgGPTthgJuenNg94pCM/5VWuM8VE/BFSPwgpxDAgO4DtpskNY2KwLoZY3G/TZ6Y5VNHB6FW5rQh7Yzb8ws2H0H8zEQs3HLU6dwr3x7EO+sO674uVFaZ2mKWRPI1BnQfuH1YV+vjqAjnH+lo5mE37ICHPUzd3fxb9MMxvKWdr89GHL7gzbvd8LdNKPZhLvWaEPzQosBiQPeBP4/vj9S/3oIHRnfHp1P009yM6dMhwLUyp6TMc27Pu5p9eCLffmGPY488VENdfce9n/1yn8tz/96e2dDqkMkxoPtARISgVbMovHLbFRgS11a3TG186dqWsxDcyS0ud7+M3kVkvu71jdbHP59x3nAiXDqvy3ZmuTyXxdWqjR4Dup90u8QSuAd3a2N3fPaky/GfqaPR1odT1szC0yIgwDK98Ffv/uDyvDLQ1775rS0ue/KenCpialwyLwb0Blj/9FjsePYm3XNNtLH0Edp0yNpA1CwqEqN6tcevhlwamEqGkGvmbjBU7kieu400jEXqRIcdhDy9atnOE8gqKMUHW5z3UjXC2zF7Iy9bd9D9jWKiWgzoDdA7thU6t4nWPXf/qO6468puePLGywAA08b2BgAM6mpZdfqX/7kcTaP44/cXxy3wbMfUN6bnOpV99sv9uHvBdq/fzzYuZ58zPvRhZChoyifuV8Ra68Cpk40eI4qftGwWhb/dPcS6AGlMn1gcnzvR+jwyQjD7V5fbveaGfv7PEd9YbT18FodOncfRvBI8YrNAyda50gqvr/91yknM1za+2HnMu1wwiftOoc/zqzzux1peVY2dxwpQWmF/ryFc7hOQ9xjQg8i2R3Vpm2h88BtT5jQzjQnvbMX/Ld/r8nx5VcNWWr6+Nh1rGpBlct7aNFRWK4/7sc5akYp73t+OZ/5rP+Ml/0IF1qRy85DGjAE9RGx65gZERfKfw1dcDT+42pLQV6Yt2Y3MfO9mm0RqlbbNDPmH/zh/AH2+OxsAkOSQFXJvViGmLUlGwQXvv2mQuTGCBFGv2LoUAbU3TR++Oj5ItQkvqR4WKPmTpyETW7YfPBHakuPFPxyzDtu426v13AX9RUlV1czp0lgxoAfRCJuEYLXjny85jKuT/yWfcL+YKVBqe+jLdp7APe97f4OWGi8G9CBr37Kp50L1dPPATj6/Zjg7dMp5IVJD5BaXe/U6zlKhhmJAD7LPp12FWRMHILpJpO75lk3rjndpE40OrTzvnfnotT19Vr/GwNeJu9wNk7iyO7OAybqowRjQg6xXbCtMGdNL91zrZvb7j/z7sVGc2ugHobB/7J3/3F7vWTZGVs1S48KAHoIWadMXm0ZF2K1AvKxjK0Nfy2v3RSVj5q5OC+r7O84nJ/IWA3oIulbLzPiHcX2tKxAXP2R8jvqALjGeC5GuQI9jC4DCUt+l0KXGjQE9BEU3icTxuRPxwOge1jXljlvk+VN3nX1UGwtf95Y9LTRqyKAJV4aSIwZ0k7k7Ic7v7/HanYP9/h6hyl16Wm98sNV5ZyW/0zoBP58pxosrU1HJeemNBgN6iKvNn167zZ3t3PXW0VF6L3HSsbXnmTG2rurdnonDfGR35jnkFZdjTuJB3dS8Gbkl+Gjbcb+899vrfsYn2zNx4GTwFllRYBmLCBQ0nzw2EjuPFaB1tPONzm+mX4ukzHP44+cpAIBrL+uAqEj7QeBxAzvh/QeuxIGT57Eq9RT+uemI3fkXfzkQ1TUKc1YdAgDrBh3Nm0SiooG5TchixJx1AIADJ89j6eP2+8vet3CH3943p9CSE4Zb0zUe7IaFuI6to/HLwfq50+M7tMRdV3azPl8yZRQ+emQkgLoe/Qe/SUBEhOCKbm3w5/H9na4RGSG4aUBHAMDSx0dh5RPXALDPJ0K+4eufqaurLdx8FPEzE1FdEz4fyKeLyjDuzc3cgMQD9tDDwJzbByH3vP3qxMSnxmDr4TyPrxWxzIU/Pnei3fEBXVpj1/HQWBIfLuo7g8bbLeUW/WDZpONihSWnTDh00JfvOoHDuSVYtjMLM8b1DXZ1QhZ76GHg16N64A8Ov+QDusRg6nW9ncrWbrjhyaKHRvikbuS9MfM2ei7khrudnyg8MaA3MvHtW9o9d9Vp5OIk3xOXP23vGE9ZELgu+tmScr8M14XDt4xAYEBvZMqq6lK79mjfAuMHdQlibRoXEe+HUfyhqLQS//rxmNsPhtnfHMQTnyYbul5+STkSXlmHF1am2h2Pn5mIa1/bgK/qkeOmrLIab33/M8qrjKciJo6hNzpREXW9xM3P3BDEmjROt76zNdhVsPrTFylYe+AM9pwoxPX9YnH7sK5Om11/+KNlPH6+9vzY2Qto27wJLrHJEvpNyknszynCPQmWG/RLfzqB6mqFKWN6WqdkZp+7iN//Zy9uG9bVUN0W/3AM76w/jHfWH8YW/p4aZqiHLiLjRSRdRDJEZKbO+f4isl1EykXkj76vJtlq0TQSfTu18lxQx53Du3kuRH4hAhSX+24lqtGRDVcd8DPajfSvU05ixmcp2H4k33rubEk5CnX2WL3hb5tw4xub7I49uWwPFm45isrqujf6T1IWxr21BZ/+dMJt3Y6dvYDEfXXb5v10NB/7sgvtNgmZtmQ33ll/2O11yMJjD11EImH5gB4HIBvALhH5Wil10KZYAYCnANzml1qSnYOzx3v9Wm+3uRvdqx12HK3b8uyay9rjx4x8LHlsFB5Y/JPX9aHQYfthk/DKOrtzySfOWdclnHORe8abtAk3vbEJNQqYONgyy+pebV6+7c17zqM3zsj/7pEAMpRSR5VSFQCWA5hkW0AplauU2gWAWYbC1CePjtI9zhSuxv2Yke+5kB/U9uQXbjmCE272O11/6AyW7MjUPXfHe9v8sgiKyx18y0hA7wrANsFFtnas3kRkqogkiUhSXp7nOdLkH/PvH47bDYxl2o5dNo2KQP/Ora3PbWdstNLytt96RWd88burfVhT8oV73t+OM+fL8P9WpeH+
Ra6D8mdJ2Zj1lfe5X9iRDj4jAV1vrpVX/3RKqYVKqQSlVEJsLDdqCJaJg7vgrXuHeizXvb1+1sV37qt7rVLApKGWlaxv3jMUV/a4BADQq0NL3ddScNy9wLJHaWmF51kjpeWey+SeL2twncj3jMxyyQZgm+KvG4CT/qkOhZonbuiNCG3mw/X9OiLtdDFG92qP/+7OtpaZPWkQ/nRLf+s2eiv+92r0aN8Sw1/+Pih1JmcnHKZLVteoBs0XP1FQivNlVdifU2g9xj1Rg89IQN8FoI+I9ASQA+A+APf7tVYUMp65pb/N43546Ooe6BQTbT2mYMkH06ZF3UKkYd0tvfSPHx2J6KgI640uCh13LdiG/TlFDbrGL97c7KPaOONng3c8BnSlVJWITAewFkAkgA+VUgdEZJp2foGIdAaQBCAGQI2I/B7AQKUU83aGkcgIQZc2lnS+j17TE1sPn8Xll7reHWlsXw6rhao9Jwo9F3JDrzfOMfTgM7SwSCm1CsAqh2MLbB6fhmUohhqJG/p3dEroRY1HQ4N3VkGpofF8qh+uFCW/W/jglejSpjkW/XAUK/fy9gsZSDzGAXmvMJcL+d3Nl3fGFd3a4P6R3d2W+9vdQwJUIwoHRaWVqOFEdjsM6BQw8R6mMt51ZTdOdzSJYHegz5aUY8js7/D2up/tjr+3KQMpWYUoraiySynQWDCgU8B0ionG4TkTcP8oS0/9jmFdrQmdam344/U4Pnci9rwwDl3aROtdhghnSyx5aP6+IcPu+Lw16Zg0/0fMWpGKJ5YmY392w2bymA0DOgVUk8gIxGsLlsYP6ozOMfpB+5KWTbH+6bH47LdXWY/NnNAfg7q6nlVDoe/pz1Jw4KT/g2z2OctWdRe8yC9jZrwpSgH32LW90K9zDK7r0wGpbnakb9E0CiN7trM+nza2N6aN7Y34mYmBqCa54e1uSF8kZ+OnY8HJadMYsIdOARcZIRjbNxYigpv6WzaojneRZgCwbKf33K3OG1xT/RVcqEBecbnngh786b/7vH5tbe/ZFjey8A320CmohsS1xc7nbkI7mw0THK3+vzF2z/t3bo2008X+rlrYCsSQR331m7UG+1662VDZjWm5mPrvJD/XyJzYQ6eg6xgTXa887T1c9Oaba7lkyL1Qneh37kLdhhruJtE8/XmK3WYa7qxINr7tXThgQCfTefWOwZgxri9+Pao7Wmupe6de1yvItaKGSrf51nXGR9kc/5OU5blQGGFAJ9Np17IpnrqpD+bcfgVev3swAODuKz1nnnAcummsPM3Pvnfhdq+u+0UDe8NbDtftkZBTWDfO7rjPqS9UVdfgD//Zi4zcEp9fO5wxdQ0AAA1CSURBVJgY0MnUxg/qgow5E9CnU2vM//UwjIxv57LsgC4xWPSbBDw+pqf12OxJlweimiHFNvWxHm/vTyzb6X7/UE9s88OUV9ZtsnHoVP1z/O08Xrdd4vubjzjlb9+fU4QVe3Lw9Ocp1mPnyyrxl5WpyC3W/3Yw8e9b8YWHn12wMaCT6dWOv9/YvxM+m3aV0/l/PzYS62aMBQD8YmAnTL+hDwAgJjoKk4Z4tfkW+YFtR9w2IHtiu6G0nldXp2H6sj0er/POusP4eHsmRs5ZbzeeX+vAyfN2HwChiAGdws7GP16PB0bX5Y0Z0ycWl3VsZX3erInl1/7yS9sw8XYIWbLDux7+9a9vQtFF99sZl5R5XmBUZbP1Xr5DQFcmyQ3MaYsUdnp2aIlXbrsCdwzvhpho51/x6CaR+OJ3V6NPp1Zo3SwKT93UBzf174hJ8390ec07hnVFTPMmmDi4i3U7Nwq8v68/7HTs9PkyTF+ajAmDuhi6hlIK5VXu9011HLZ3lwPsp6P5qFYKV/fuYOj9/YkBncLWcG3nJD21e58CwIxxfQEA6a+MR9PICPR8dpVT+TcN7MFKwbMvuwhbD581VHbJjky8sPKA2zKOHXJ32/XV7sgVCvsDcMiFSNMsKhIigsgI++7Zn8cbW6X698nD/FEtMqDG4JDIrK/2ewzmDbl+sDGgEzn44c83YOUT11if12aH1NOzQ0v079waANC9nev0BRR8Fyuq3Y7Tu5se6U1AL7hQoXtz1Z845ELkoEub5ta9UwGgTfMmuuXSXh6PCBEcPVuC19ekY2AXZoIMZQNeXOP1a90Nuej5YMtRzFl1CABweM4EXCivQtsWrtNb+Ap76EReim4SiaZREejfOQaLHx6BplH2/51evm0QerRvoZujpE/HVpg09FK3m2xTYLmbyVJjcw/1423H3V6noqrGGswB4MmlezB09vcNrZ4h7KETubB86mik5jgnsrqubyy6ttXP475t5o1oGhWBDq2aAQAeHN3D7vzztw5A5zbR+J8hlwKwJMp6cPFOFAT4q3m4KfYwLfGgF4uTbFXbBPu/fH0AD10d77Ks4/DMmgOnAVg25aj9vfAX9tCJXBjdqz2mjHHOEfPJoyPx6h2DdV9zadvmbv/TThnT0xrMActc+OQXxjGxWJCkZBXiV+/+gBP5pQ5j6PZB2dUYen2GYhJeWedNFeuFPXSiAHJ1423zM9cjt7gcTSIjkFdcjgcW/2R3/sHRPfDvHZkALNPjfr98D77ae9J6/pbLO2H2pEHIKbyIO97b5r8GhKF92UW47vWNdscc43SVQ3bHsspqRDeJxBOfJluPFV2sxNOf7cW6Q7lu32/Kx0m4+fJOuCchrmEV18EeOlEA/PVXl+ORa+Jdnu8YE41BXdugX+fWuLZPB7x17xB8M/1a6/khcW3dXr97uxboFBONYXFt8dqdVyCuXXO35cm9zPxS65j6oq1HMfrV9Xbn+7+wBp8lZVmHUwBg/NtbPAbzpOMF2PxzLo6d9W7HJ0/YQycKAHdjrnpuH2bJHnl87kRkFZSi2yXNkX76PD7Sbsg5ftGvzWcjIrh3RHfszSrEsp2NK3WsLz3+SRLm3TkY94yIwyuJh3TLOO7adKrIc8rfr1NOorJaIcJPKScY0IlCXJw2v/35iQPx/MSBAICYaMtUynl3Dcbmn/Mw1WGs3yTrYELan77Yh6xzpT695ifbLcNmEX5ICQwwoBOZ0swJ/dGjfQvcNbyb7lhsNG+y+sQ/NmT45br+ygnHMXQiE2rZLApTxvRChIvv7k/f3BcdWzvPtrFdAfu0lsOGAs8fm3YA7KEThaXW0U3w/Yyx+M2HOzFr4gDsPVGI8qpqDO7WBhueHouK6hrEXdIC247kY/vRfLvXDugSg9uGXopXV6e5vP7YvrHIv1CO1Bz387uHxrXF3qxC3XMtm0biQoX7XObhyldb7DliD50oTLVp3gQrn7gGI+Lb4fHremH6jX0gIugV2wr9O8egZbMoLJs62u41Cx4YjqVTRuG3Y3vj+VsHYPLIOGTMmYCX/megtUxUhODjR0di4hWXOr4l5tw+CMdevRXP3NIPAHDtZR1waPZ4uzJd2zbH42N64k9a0rN+nVpbz/3u+t5u29S1bXjM3lm+yz83rNlDJ2rkOrRqhuv7xWLenYPthnAet9l4++FreuKuhDik5hShj7ZZiN5
im1+PsqyMffjqeGQVlOK3Y3uhedO68fx/PTICw+Laom2LpigqrcSq/afw5r1DUVxWiUOnzuP2Yd1w/8juGDNvIy5tE42TNjNH1vx+DFo1i8KLKw/gpgEdcU9CHJIzz2FY90uw9XAeHvs4yec/G7MRIztxiMh4AO8AiASwSCk11+G8aOdvBVAK4GGlVLLThWwkJCSopCT+AxCZ1bkLFXhq+R68ec9QxDSPQklZFdq7WCU7f2MGYpo3cUqF4E72uVL8bkkyFj+cgI6t9VMt2NqbVYjW0VHo0KoZ2jRvgg1pZ5B+ugT3johDXnE5+nVujb9+cwD/+vE4AGDZ46Mx+YMddtcY06cDXrtzMK6euwEAMHlkd+zOLMDU63rjj5+n4I7hXfFlcg4iBHj/wQSsP3QGy3dlYVTPdvjpmPO2eY9e0xMf/ngMHVo1w9mScuvxPS+MwyUtvUvWJSK7lVIJuuc8BXQRiQTwM4BxALIB7AIwWSl10KbMrQCehCWgjwLwjlJqlLvrMqATUaCVlFfhH+sPY8bNfdEsKhJZBaXIv1CBoXFtUVOjINKwG5ZZBaUouFChuxBMKYWFW47i3hFxDcq82NCAfhWAl5RSt2jPn9Uq96pNmfcBbFJKLdOepwO4Xil1ytV1GdCJiOrPXUA3clO0KwDbEfxs7Vh9y0BEpopIkogk5eXlGXhrIiIyykhA1/v+4ditN1IGSqmFSqkEpVRCbGyskfoREZFBRgJ6NgDbpWjdAJz0ogwREfmRkYC+C0AfEekpIk0B3Afga4cyXwP4jViMBlDkbvyciIh8z+M8dKVUlYhMB7AWlmmLHyqlDojINO38AgCrYJnhkgHLtMVH/FdlIiLSY2hhkVJqFSxB2/bYApvHCsATvq0aERHVB5f+ExGFCQZ0IqIwYWjpv1/eWCQPQKaXL+8A4KwPqxNMbEtoCpe2hEs7ALalVg+llO6876AF9IYQkSRXK6XMhm0JTeHSlnBpB8C2GMEhFyKiMMGATkQUJswa0BcGuwI+xLaEpnBpS7i0A2BbPDLlGDoRETkzaw+diIgcMKATEYUJ0wV0ERkvIukikiEiM4NdHz0i8qGI5IpIqs2xdiLyvYgc1v6+xObcs1p70kXkFpvjV4rIfu3c36UhW6l41444EdkoIodE5ICI/J+J2xItIjtFJEVry1/N2hatDpEiskdEvjV5O45rddgrIkkmb0tbEfmviKRp/2euCnhblFKm+QNLcrAjAHoBaAogBcDAYNdLp57XARgOINXm2DwAM7XHMwG8pj0eqLWjGYCeWvsitXM7AVwFS7751QAmBLgdXQAM1x63hmUrwoEmbYsAaKU9bgLgJwCjzdgWrQ4zACwF8K1Zf7+0OhwH0MHhmFnb8jGAKdrjpgDaBrotAW2wD35gVwFYa/P8WQDPBrteLuoaD/uAng6gi/a4C4B0vTbAktXyKq1Mms3xyQDeD3KbVsKyt6yp2wKgBYBkWPa/NV1bYNlvYD2AG1EX0E3XDu19j8M5oJuuLQBiAByDNtEkWG0x25CLoa3uQlQnpeWI1/7uqB131aau2mPH40EhIvEAhsHSszVlW7Rhir0AcgF8r5Qya1veBvAnADU2x8zYDsCys9l3IrJbRKZqx8zYll4A8gD8SxsKWyQiLRHgtpgtoBva6s5kXLUpZNoqIq0AfAHg90qp8+6K6hwLmbYopaqVUkNh6eGOFJFBboqHZFtE5JcAcpVSu42+ROdY0Nth4xql1HAAEwA8ISLXuSkbym2JgmWY9Z9KqWEALsAyxOKKX9pitoBu5q3uzohIFwDQ/s7VjrtqU7b22PF4QIlIE1iC+adKqS+1w6ZsSy2lVCGATQDGw3xtuQbAr0TkOIDlAG4UkSUwXzsAAEqpk9rfuQBWABgJc7YlG0C29q0PAP4LS4APaFvMFtCNbIcXqr4G8JD2+CFYxqNrj98nIs1EpCeAPgB2al/PikVktHaX+zc2rwkI7X0XAziklHrT5pQZ2xIrIm21x80B/AJAGkzWFqXUs0qpbkqpeFh+/zcopR4wWzsAQERaikjr2scAbgaQChO2RSl1GkCWiPTTDt0E4CAC3ZZA3wTxwc2HW2GZbXEEwPPBro+LOi4DcApAJSyfuI8BaA/LjazD2t/tbMo/r7UnHTZ3tAEkwPILfgTAu3C44RKAdlwLy9e9fQD2an9uNWlbBgPYo7UlFcCL2nHTtcWmHtej7qao6doBy7hzivbnQO3/ZzO2RavDUABJ2u/YVwAuCXRbuPSfiChMmG3IhYiIXGBAJyIKEwzoRERhggGdiChMMKATEYUJBnQiojDBgE5EFCb+Pz0X/pj2x/PZAAAAAElFTkSuQmCC\n", - "text/plain": [ - "
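The training cell removed above follows the standard PyTorch update pattern: a forward pass produces a correlation map, an MSE loss compares it to the Dirac ground truth `vt`, then backpropagation and an optimizer step adjust the Conv1 weights, with GPU memory freed between fragments. A minimal self-contained sketch of that step, reusing the notebook's names (`net`, `optimizer`, `vt`); this is an illustration of the pattern, not a line-for-line copy of the deleted cell:

    # Sketch of one training step, assuming `net` implements forward(img, frag, using_cuda)
    # and `vt` is the Dirac ground-truth map built by get_training_fragment / load_training_fragment.
    import torch
    import torch.nn as nn

    def train_step(net, optimizer, fresque_tensor, frag_tensor, vt, using_cuda=True):
        loss_func = nn.MSELoss()
        frag_pred = net.forward(fresque_tensor, frag_tensor, using_cuda)
        err_ = loss_func(vt, frag_pred)      # MSE between predicted map and the Dirac ground truth
        optimizer.zero_grad()                # clear gradients left over from the previous fragment
        err_.backward(retain_graph=True)     # backpropagate through the correlation pipeline
        optimizer.step()                     # update the learnable Conv1 weights
        loss = err_.item()
        del frag_pred, err_                  # drop references so the cache can actually be released
        if using_cuda:
            torch.cuda.empty_cache()         # release cached GPU memory between fragments
        return loss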
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.plot(loss_value)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "file_path=\"./net_trainned6000\"\n", - "save_net(file_path,net)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/Apprentissage_initial_dataset.ipynb b/Apprentissage_initial_dataset.ipynb index 7e61bde..82634f4 100644 --- a/Apprentissage_initial_dataset.ipynb +++ b/Apprentissage_initial_dataset.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -75,11 +75,15 @@ " data_vt_raw = f.readlines()\n", " data_vt = [int(d.rstrip('\\r\\n')) for d in data_vt_raw]\n", " \n", + " #print(\"[load_training_fragment] vt: {}:{}\".format(data_vt[2],data_vt[3]))\n", + " \n", " # Construct vt\n", " vt = np.zeros((int(data_vt[0]/4)+1,int(data_vt[1]/4)+1))\n", + " #print(\"[load_training_fragment] point 1 placé en : {}:{}\".format(data_vt[2]/4,data_vt[3]/4))\n", " vt[int(data_vt[2]/4),int(data_vt[3]/4)] = 1\n", " vt = np.float32(vt)\n", " vt = torch.from_numpy(vt.reshape(1,1,int(data_vt[0]/4)+1,int(data_vt[1]/4)+1))\n", + " #print(\"[load_training_fragment] taille de la vt: HxW {}x{}\".format(vt.shape[2],vt.shape[3]))\n", " \n", " return(frag,vt)\n", "\n", @@ -97,11 +101,11 @@ "# Trouvez les coordonnées de la valeur maximale dans une carte de corrélation\n", "# x,y=show_coordonnee(carte de corrélation)\n", "def show_coordonnee(position_pred):\n", - " map_corre=position_pred.squeeze().detach().numpy()\n", + " map_corre=position_pred.squeeze().detach().cpu().numpy()\n", " h,w=map_corre.shape\n", " max_value=map_corre.max()\n", " coordonnee=np.where(map_corre==max_value)\n", - " return coordonnee[0].mean()/h,coordonnee[1].mean()/w\n", + " return coordonnee[0].mean(),coordonnee[1].mean()\n", "\n", "# Filtrer les patchs en fonction du nombre de pixels noirs dans le patch\n", "# Si seuls les pixels non noirs sont plus grands qu'une certaine proportion(seuillage), revenez à True, sinon False\n", @@ -129,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -321,7 +325,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -366,15 +370,15 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Fresque 0, fragment 2824/3000 (94.1%)\n", - "Temps par fragment: 0.759\n" + "Net sauvegardés dans ./trained_net/net_trainned_02-02_21-44_0002\n", + "Poids sauvegardés dans ./trained_net/save_weights_02-02_21-44_2\n" ] } ], @@ -388,7 +392,7 @@ " using_cuda=True\n", " \n", " # Variable des données\n", - " base_dir = './training_data_small/'\n", + " base_dir = './training_data_maybe-good/'\n", " fresque_filename = 
base_dir+'fresque_small{}.ppm'\n", " fresque_filename_wild = base_dir+'fresque_small*.ppm'\n", " fragment_filename = base_dir+'fragments/fresque{}/frag_dev_{:05}.ppm'\n", @@ -432,15 +436,15 @@ " if using_cuda:\n", " fresque_tensor=fresque_tensor.cuda()\n", " \n", - " # Detection des fragments d'entrainement\n", + " # Recherche des fragments d'entrainement\n", " fragments_paths = glob(fragments_filename_wild.format(fresque_id))\n", " N_fragments = len(fragments_paths)\n", " for fragment_id,fragment_path in enumerate(fragments_paths):\n", - " clear_output(wait=True)\n", + " #clear_output(wait=True)\n", " print(\"Fresque {}, fragment {}/{} ({:.3}%)\".format(fresque_id,fragment_id,N_fragments,(fragment_id/N_fragments)*100))\n", " print(\"Temps par fragment: {:.3}\".format(time()-time_old))\n", " time_old = time()\n", - " # Tous les 100 cycles, enregistrez le changement de poids\n", + " # De temps en temps, enregistrez les nouveau poids\n", " if fragment_id%50==0:\n", " w_values.append(net.conv1.weight.data.cpu().numpy())\n", " \n", @@ -457,6 +461,8 @@ " b,c,h,w=vt.shape\n", " # Utilisez la fonction de coût pour calculer l'erreur\n", " err_=loss_func(vt,frag_pred)\n", + " #print(\"[MAIN] position choisie dans la carte de correlation: {}\".format(show_coordonnee(frag_pred)))\n", + " #print(\"[MAIN] Valeur de la loss: {}\".format(err_.tolist()))\n", " # Utilisez l'optimiseur pour ajuster le poids de Conv1\n", " optimizer.zero_grad()\n", " err_.backward(retain_graph=True)\n", @@ -469,12 +475,12 @@ " \n", " # Sauvegarder le réseau\n", " save_dir = './trained_net/'\n", - " extension = 'from-random_full-dataset-small'\n", - " net_filename = save_dir + \"net_trainned_{}_{}\".format(extension,datetime.now().strftime(\"%m-%d_%H-%M\"))\n", + " expe_id = 2\n", + " net_filename = save_dir + \"net_trainned_{}_{:04}\".format(datetime.now().strftime(\"%m-%d_%H-%M\"),expe_id)\n", " save_net(net_filename,net)\n", " \n", " # Sauvegarder les poids\n", - " poids_filename = save_dir + \"save_weights_{}_{}\".format(extension,datetime.now().strftime(\"%m-%d_%H-%M\"))\n", + " poids_filename = save_dir + \"save_weights_{}_{:04}\".format(datetime.now().strftime(\"%m-%d_%H-%M\"),expe_id)\n", " with open(poids_filename,'wb') as f:\n", " pickle.dump(w_values,f)\n", " \n", diff --git a/Benchmark.ipynb b/Benchmark.ipynb index ba73a3e..e528b98 100755 --- a/Benchmark.ipynb +++ b/Benchmark.ipynb @@ -503,15 +503,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Fresque 0, fragment 34/200 (17.0%)\n", - "Temps par fragment: 12.3. ETA = 2.05e+03s\n" + "Fresque 5, fragment 255/256 (99.6%)\n", + "Temps par fragment: 11.6. 
ETA = 11.6s\n", + "Sauvegardé dans ./results_bench/results_bench_f5_02-04_17-53_0003\n" ] } ], @@ -525,15 +526,16 @@ " #fresque_id = 2\n", "\n", " # Variable des données\n", - " base_dir = './training_data_small/'\n", + " base_dir = './fragments_complets/'\n", " fresque_filename = base_dir+'fresque_small{}.ppm'\n", " fresque_filename_wild = base_dir+'fresque_small*.ppm'\n", " fragment_filename = base_dir+'fragments/fresque{}/frag_bench_{:05}.ppm'\n", " fragments_filename_wild = base_dir+'fragments/fresque{}/frag_bench_*.ppm'\n", " vt_filename = base_dir+'fragments/fresque{}/vt/frag_bench_{:05}_vt.txt'\n", - " net_filename = \"./trained_net/net_trainned_from-random_full-dataset-small_01-29_18-14_0001\"\n", + " net_filename = \"./trained_net/net_trainned_02-03_01-33_0002\"\n", " \n", - " expe_id = int(net_filename.split(\"_\")[-1]) # ID de l'expérience, à ajouter à tout les fichiers écrits pour identifier les résultats d'une même expérience.\n", + " #expe_id = int(net_filename.split(\"_\")[-1]) # ID de l'expérience, à ajouter à tout les fichiers écrits pour identifier les résultats d'une même expérience.\n", + " expe_id = 3\n", " date = datetime.now().strftime(\"%m-%d_%H-%M\")\n", " results_filename = './results_bench/results_bench_f{}_{}_{:04}'.format(fresque_id,date,expe_id)\n", "\n", diff --git a/Benchmark_MB.ipynb b/Benchmark_MB.ipynb new file mode 100644 index 0000000..9459bfd --- /dev/null +++ b/Benchmark_MB.ipynb @@ -0,0 +1,723 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "#Tous les codes sont basés sur l'environnement suivant\n", + "#python 3.7\n", + "#opencv 3.1.0\n", + "#pytorch 1.4.0\n", + "\n", + "import torch\n", + "from torch.autograd import Variable\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import cv2\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import random\n", + "import math\n", + "import pickle\n", + "import random\n", + "from PIL import Image\n", + "import sys\n", + "from glob import glob\n", + "from IPython.display import clear_output\n", + "from datetime import datetime\n", + "import json\n", + "from time import time\n", + "from PIL import Image\n", + "from torchvision import transforms" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Les fonctions dans ce bloc ne sont pas utilisées par le réseau, mais certaines fonctions d'outils\n", + "\n", + "# Les fonctions de ce bloc se trouvent dans le programme d'apprentissage \n", + "# “Apprentissage_MSELoss_avec_GPU“\n", + "# et les commentaires détaillés se trouvent dans le programme d'apprentissage\n", + "\n", + "def tensor_imshow(im_tensor,cannel):\n", + " b,c,h,w=im_tensor.shape\n", + " if c==1:\n", + " plt.imshow(im_tensor.squeeze().detach().numpy())\n", + " else:\n", + " plt.imshow(im_tensor.squeeze().detach().numpy()[cannel,:])\n", + " \n", + "def get_training_fragment(frag_size,im):\n", + " h,w,c=im.shape\n", + " n=random.randint(0,int(h/frag_size)-1)\n", + " m=random.randint(0,int(w/frag_size)-1)\n", + " \n", + " shape=frag_size/4\n", + " vt_h=math.ceil((h+1)/shape)\n", + " vt_w=math.ceil((w+1)/shape)\n", + " vt=np.zeros([vt_h,vt_w])\n", + " vt_h_po=round((vt_h-1)*(n*frag_size/(h-1)+(n+1)*frag_size/(h-1))/2)\n", + " vt_w_po=round((vt_w-1)*(m*frag_size/(w-1)+(m+1)*frag_size/(w-1))/2)\n", + " vt[vt_h_po,vt_w_po]=1\n", + " vt = np.float32(vt)\n", + " vt=torch.from_numpy(vt.reshape(1,1,vt_h,vt_w))\n", + " \n", + " return 
im[n*frag_size:(n+1)*frag_size,m*frag_size:(m+1)*frag_size,:],vt\n", + "\n", + "def write_result_in_file(result,file_name):\n", + " n=0\n", + " with open(file_name,'w') as file:\n", + " for i in range(len(result)):\n", + " while n=2 and m>=2:# Si n=m=1,Notre réseau n'a plus besoin de plus de couches pour agréger les cartes de corrélation\n", + " self.shift2=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n", + " self.shift2.weight=kernel_shift_ini(n,m)\n", + " self.add2 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n", + " self.add2.weight=kernel_add_ini(n,m)\n", + " \n", + " n=int(n/2)\n", + " m=int(m/2)\n", + " if n>=2 and m>=2:\n", + " self.shift3=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n", + " self.shift3.weight=kernel_shift_ini(n,m)\n", + " self.add3 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n", + " self.add3.weight=kernel_add_ini(n,m)\n", + " \n", + " def get_descripteur(self,img,using_cuda):\n", + " # Utilisez Conv1 pour calculer le descripteur,\n", + " descripteur_img=self.Relu(self.conv1(img))\n", + " b,c,h,w=descripteur_img.shape\n", + " couche_constante = 0.5 * torch.ones([b, 1, h, w])\n", + " if using_cuda:\n", + " couche_constante=couche_constante.cuda()\n", + " # Ajouter une couche constante pour éviter la division par 0 lors de la normalisation\n", + " descripteur_img = torch.cat((descripteur_img,couche_constante),1)\n", + " # la normalisation\n", + " descripteur_img_norm = F.normalize(descripteur_img) #/ torch.norm(descripteur_img)\n", + " return descripteur_img_norm\n", + " \n", + " def forward(self,img,frag,using_cuda):\n", + " psize=4\n", + " # Utilisez Conv1 pour calculer le descripteur,\n", + " descripteur_input2=self.get_descripteur(frag,using_cuda)\n", + " descripteur_input1=self.get_descripteur(img,using_cuda)\n", + " \n", + " b,c,h,w=frag.shape\n", + " n=int(h/psize)\n", + " m=int(w/psize)\n", + " \n", + " db,dc,dh,dw = descripteur_input1.shape\n", + " \n", + " #######################################\n", + " # Calculer la carte de corrélation par convolution pour les n*m patchs plus petit.\n", + " for i in range(n):\n", + " for j in range(m):\n", + " if i==0 and j==0:\n", + " ##HAD TO CHANGE THIS LINE BECAUSE OF CONVOLUTION DIMENSION FOR BATCHES\n", + " map_corre=F.conv2d(descripteur_input1.view(1,db*dc,dh,dw),get_patch(descripteur_input2,psize,i,j),padding=2,groups=db)#.detach()\n", + " map_corre=map_corre.view(db,1,map_corre.size(2),map_corre.size(3))#.detach()\n", + " else:\n", + " a=F.conv2d(descripteur_input1.view(1,db*dc,dh,dw),get_patch(descripteur_input2,psize,i,j),padding=2, groups=db)#.detach()\n", + " a=a.view(db,1,a.size(2),a.size(3))#.detach()\n", + " map_corre=torch.cat((map_corre,a),1)#.detach()\n", + " \n", + " ########################################\n", + " # Étape de polymérisation\n", + " map_corre=self.maxpooling(map_corre)\n", + " map_corre=self.shift1(map_corre)\n", + " map_corre=self.add1(map_corre)\n", + " \n", + " #########################################\n", + " # Répétez l'étape d'agrégation jusqu'à obtenir le graphique de corrélation du patch d'entrée\n", + " n=int(n/2)\n", + " m=int(m/2)\n", + " if n>=2 and m>=2:\n", + " map_corre=self.maxpooling(map_corre)\n", + " map_corre=self.shift2(map_corre)\n", + " map_corre=self.add2(map_corre)\n", + " \n", + " \n", + " n=int(n/2)\n", + " m=int(m/2)\n", + " if n>=2 and m>=2:\n", + " map_corre=self.maxpooling(map_corre)\n", + " map_corre=self.shift3(map_corre)\n", + " map_corre=self.add3(map_corre)\n", + " \n", + " \n", + " 
#b,c,h,w=map_corre.shape\n", + " # Normalisation de la division par maximum\n", + " map_corre=map_corre/map_corre.max()#.detach()\n", + " # Normalisation SoftMax\n", + " #map_corre=(F.softmax(map_corre.reshape(1,1,h*w,1),dim=2)).reshape(b,c,h,w)\n", + " return map_corre" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Les fonctions de ce bloc sont utilisées pour appliquer le réseau à des fragments (pas à des patchs carrés)\n", + "\n", + "\n", + "# Cette fonction permet de sélectionner un ensemble de patchs carrés à partir d'un fragment\n", + "# Le paramètre “frag_size” fait ici référence à la taille du patch d'entrée carré (16 * 16)\n", + "# Le paramètre “seuillage” limite la proportion de pixels non noirs dans chaque patch\n", + "# Le paramètre “limite” peut limiter le nombre de correctifs trouvés dans chaque fragment\n", + "def get_patch_list(frag,frag_size,limite,seuillage):\n", + " n=0\n", + " m=0\n", + " h,w,c=frag.shape\n", + " patch_list=[]\n", + " position_list=[]\n", + " for i in range(4):\n", + " if len(patch_list)>limite and limite!=0:\n", + " break\n", + " for j in range(4):\n", + " if len(patch_list)>limite and limite!=0:\n", + " break\n", + " n_offset=i*4 # n offset\n", + " m_offset=j*4 # m offset\n", + " n=0\n", + " while n+frag_size+n_offset0:\n", + " rot_frag=math.atan(tan_rot)*(180/math.pi)\n", + " else:\n", + " rot_frag=math.atan(tan_rot)*(180/math.pi)+180\n", + " rot_frag=-rot_frag\n", + " if rot_frag>0:\n", + " rot_frag-=360\n", + " return centre[0][0],centre[1][0],rot_frag\n", + "\n", + "# Vérifiez les résultats de Ransac en avec des changements de distance euclidienne\n", + "def test_frag(inline,frag,fres):\n", + " itera=10\n", + " frag_inline=[]\n", + " fres_inline=[]\n", + " # Metter les coordonnées du point inline dans \"frag_inline[]\",et \"fres_inline[]\"\n", + " for i in range(np.size(inline,0)):\n", + " if inline[i]==1:\n", + " frag_inline.append([frag[i][0],frag[i][1]])\n", + " fres_inline.append([fres[i][0],fres[i][1]])\n", + " p=[]\n", + " \n", + " # Faites une boucle dix fois, \n", + " # sélectionnez à chaque fois deux paires correspondantes inline \n", + " # calculer le changement de leur distance euclidienne\n", + " for i in range(itera):\n", + " point_test=selectionner_points(2,np.size(frag_inline,0))\n", + " diff_x_frag=frag_inline[point_test[1]][0]-frag_inline[point_test[0]][0]\n", + " diff_y_frag=frag_inline[point_test[1]][1]-frag_inline[point_test[0]][1]\n", + " diff_frag=math.sqrt(math.pow(diff_x_frag,2)+math.pow(diff_y_frag,2))\n", + " \n", + " diff_x_fres=fres_inline[point_test[1]][0]-fres_inline[point_test[0]][0]\n", + " diff_y_fres=fres_inline[point_test[1]][1]-fres_inline[point_test[0]][1]\n", + " diff_fres=math.sqrt(math.pow(diff_x_fres,2)+math.pow(diff_y_fres,2))\n", + " if diff_frag !=0:\n", + " fsf=diff_fres/diff_frag\n", + " p.append([fsf])\n", + " result=np.mean(p)\n", + " return result\n", + "\n", + "def frag_match(frag,img,position):\n", + " \n", + " frag_size=frag.size\n", + " centre_frag=creer_point(frag_size[0]/2,frag_size[1]/2)\n", + " \n", + " retained_matches = []\n", + " frag=[]\n", + " fres=[]\n", + " \n", + " for i in range(len(position)):\n", + " frag.append([float(position[i][0]),float(position[i][1])])\n", + " fres.append([float(position[i][2]),float(position[i][3])])\n", + " \n", + " if np.size(frag)>0:\n", + " # Calculer la matrice de transformation affine à l'aide de la méthode Ransac\n", + " 
h,inline=cv2.estimateAffinePartial2D(np.array(frag),np.array(fres))\n", + " # Si “h” n'est pas sous la forme de matrice 2 * 3, la matrice de transformation affine n'est pas trouvée\n", + " if np.size(h)!=6:\n", + " return ([-1])\n", + " else:\n", + " x,y,rot=position_rotation(h,centre_frag)\n", + " pourcenttage=sum(inline)/np.size(frag,0)\n", + " # Le nombre de points inline doit être supérieur à un certain nombre\n", + " if sum(inline)>3:\n", + " p=test_frag(inline,frag,fres)\n", + " # La distance euclidienne entre les points correspondants ne doit pas trop changer, \n", + " # sinon cela prouve que le résultat de Ransac est incorrect\n", + " # ici,le changement de la distance euclidienne sont entre 0.7 et 1.3\n", + " if abs(p-1)<0.3:\n", + " # Ce n'est qu'alors que Ransac renvoie le résultat correct\n", + " return([round(x),round(y),round(rot,3)])\n", + " else:\n", + " return ([-2])\n", + " else:\n", + " return ([-3])\n", + " else:\n", + " return ([-4]) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fresque 1, fragment 97/100 (97.0%)\n", + "Temps par fragment: 12.0. ETA = 35.9s\n" + ] + } + ], + "source": [ + "if __name__==\"__main__\":\n", + " \n", + " # Variable du réseau\n", + " frag_size=16\n", + " using_cuda=True\n", + " for fresque_id in range(6):\n", + " #fresque_id = 2\n", + "\n", + " # Variable des données\n", + " base_dir = './training_data_random_shift_color/'\n", + " fresque_filename = base_dir+'fresque{}.ppm'\n", + " fresque_filename_wild = base_dir+'fresque*.ppm'\n", + " fragment_filename = base_dir+'fragments/fresque{}/frag_bench_{:05}.ppm'\n", + " fragments_filename_wild = base_dir+'fragments/fresque{}/frag_bench_*.ppm'\n", + " vt_filename = base_dir+'fragments/fresque{}/vt/frag_bench_{:05}_vt.txt'\n", + " net_filename = \"./trained_net/net_trainned_MB4_02-10_20-49_0003\"\n", + " \n", + " expe_id = int(net_filename.split(\"_\")[-1]) # ID de l'expérience, à ajouter à tout les fichiers écrits pour identifier les résultats d'une même expérience.\n", + " date = datetime.now().strftime(\"%m-%d_%H-%M\")\n", + " results_filename = './results_bench/results_bench_f{}_{}_{:04}'.format(fresque_id,date,expe_id)\n", + "\n", + " # Chargement du réseau\n", + " net=load_net(net_filename)\n", + "\n", + " # Charge la fresque en mémoire\n", + " img=Image.open(fresque_filename.format(fresque_id))\n", + " \n", + " #N_fragments = 20\n", + " N_fragments = len(glob(fragments_filename_wild.format(fresque_id)))\n", + " N_fragments = 100\n", + " #print(fragments_filename_wild.format(fresque_id))\n", + " print(N_fragments)\n", + "\n", + " # Crée les tableau de résultats\n", + " distances, matched, positions, verite_terrain = [],[],[],[]\n", + " tailles = []\n", + "\n", + " time_old = time()\n", + " # Parcour tout les fragments de bench de cette fresque\n", + " for fragment_id in range(N_fragments):\n", + " clear_output(wait=True)\n", + " print(\"Fresque {}, fragment {}/{} ({:.3}%)\".format(fresque_id,fragment_id,N_fragments,(fragment_id/N_fragments*100)))\n", + " delta = time()-time_old\n", + " print(\"Temps par fragment: {:.3}. 
ETA = {:.3}s\".format(delta,(N_fragments-fragment_id)*delta))\n", + " time_old = time()\n", + " frag = Image.open(fragment_filename.format(fresque_id,fragment_id))\n", + "\n", + " # Faites pivoter les pièces de 20 degrés à chaque fois pour correspondre, répétez 18 fois\n", + " for i in [0,17]:\n", + " rotation=20*i\n", + " #rotation=0\n", + " #rotation_base=0\n", + " score_list,positions_patchs=run_net_v3(net,img,frag,frag_size,60,0.7,using_cuda,rotation)\n", + " frag_position=frag_match(frag,img,positions_patchs)\n", + " # Lorsque Ransac obtient le bon résultat, sortez de la boucle\n", + " if len(frag_position)==3:\n", + " rotation_base=i*20\n", + " break\n", + " # Si Ransac trouve une solution, la variable renvoyé est une liste de deux positions et une rotation\n", + " if len(frag_position)==3:\n", + " \n", + " # MATCHED\n", + " matched.append(1)\n", + "\n", + " # POSITION\n", + " frag_position[2]=rotation_base-360-frag_position[2]\n", + " if frag_position[2]>0:\n", + " frag_position[2]=frag_position[2]-360\n", + " positions.append([frag_position[0],frag_position[1],round(frag_position[2],3)])\n", + "\n", + " # VERITE TERRAIN\n", + " with open(vt_filename.format(fresque_id,fragment_id), 'r') as f:\n", + " data_vt = f.read().splitlines()\n", + " verite_terrain.append([int(data_vt[2]),int(data_vt[3]),frag.size[0],frag.size[1]])\n", + "\n", + " # DISTANCE\n", + " distances.append(np.linalg.norm([float(data_vt[3])-float(frag_position[0]),float(data_vt[2])-float(frag_position[1])]))\n", + " else:\n", + " matched.append(0)\n", + " distances.append(-1)\n", + " positions.append([])\n", + " verite_terrain.append([])\n", + "\n", + " del frag\n", + "\n", + " meta = {'date':date,'base_dir':base_dir,'fresque_id':fresque_id,'fresque_taille':img.size,'N_fragments': N_fragments,'expe_id': expe_id}\n", + " res = {'meta':meta, 'matched':matched,'distances':distances,'positions':positions,'vt':verite_terrain}\n", + "\n", + " with open(results_filename,'w') as f:\n", + " f.write(json.dumps(res))\n", + "\n", + " print(\"Sauvegardé dans {}\".format(results_filename))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sauvegarder dans results_f0_02-08_23-15\n" + ] + } + ], + "source": [ + "date = datetime.now().strftime(\"%m-%d_%H-%M\")\n", + "meta = {'date':date,'base_dir':base_dir,'fresque_id':fresque_id,'fresque_taille':img.size,'N_fragments': N_fragments}\n", + "res = {'meta':meta, 'matched':matched,'distances':distances,'positions':positions,'vt':verite_terrain}\n", + "\n", + "with open('results_bench/results_bench_from-random_full-dataset-small_MB9_f{}_{}'.format(fresque_id,date),'w') as f:\n", + " f.write(json.dumps(res))\n", + "\n", + "print(\"Sauvegarder dans {}\".format('results_f{}_{}'.format(fresque_id,date)))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 
1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(matched)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "80"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Clear GPU memory \n",
+    "import gc\n",
+    "torch.cuda.empty_cache()\n",
+    "gc.collect()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/display_bench.ipynb b/display_bench.ipynb
index c1d8392..e2c159c 100644
--- a/display_bench.ipynb
+++ b/display_bench.ipynb
@@ -9,7 +9,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 20,
+  "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -26,13 +26,13 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def carte(matched,positions,vt,meta):\n",
    "    \n",
-   "    fresque = cv2.imread(meta['base_dir']+'fresque_small{}.ppm'.format(meta['fresque_id']))\n",
+   "    fresque = cv2.imread(meta['base_dir']+'fresque{}.ppm'.format(meta['fresque_id']))\n",
    "    \n",
    "    fig,ax = plt.subplots()\n",
    "    ax.imshow(fresque)\n",
@@ -41,21 +41,39 @@
    "        #ax.arrow(vt[i][0],vt[i][1],p[i][0]-vt[i][0],p[i][1]-vt[i][1])\n",
    "        ax.plot([vt[i][0],p[i][0]],[vt[i][1],p[i][1]],marker='D',color='red')\n",
    "        ax.plot([vt[i][0]],[vt[i][1]],marker='D',color='green')\n",
-   "    fig.show()"
+   "    fig.show()\n",
+   "    \n",
+   "def correlation(matched, position, vt, d, meta):\n",
+   "    \n",
+   "    fig,ax = plt.subplots()\n",
+   "    for i in range(len(matched)):\n",
+   "        if matched[i] == 1:\n",
+   "            frag = cv2.imread('./training_data_small/fragments/fresque{}/frag_bench_{:05}.ppm'.format(meta['fresque_id'],i))\n",
+   "            ax.scatter(frag.shape[1],frag.shape[0],s=d[i]*2,alpha=0.5)\n",
+   "    ax.set_xlabel(\"Width\")\n",
+   "    ax.set_ylabel('Height')\n",
+   "    ax.set_title(\"Erreur de placement en fonction de la hauteur et la largeur des fragments.\")\n",
+   "    fig.show()\n",
+   "    \n",
+   "def distance_vecteur(matched,p,v):\n",
+   "    \n",
+   "    fig, ax = plt.subplots()\n",
+   "    for i in range(len(matched)):\n",
+   "        if matched[i] == 1:\n",
+   "            vecteur = (v[i][0]-p[i][0],v[i][1]-p[i][1])\n",
+   "            #print('{}:{} {}:{}'.format(v[i][0], v[i][1], vecteur[0], vecteur[1]))\n",
+   "            ax.scatter(vecteur[0],vecteur[1],s = (vecteur[0]**2+vecteur[1]**2)**0.5)\n",
+   "    \n",
+   "    ax.set_xlabel(\"W\")\n",
+   "    ax.set_ylabel('H')\n",
+   "    ax.set_title(\"Vecteur d'erreur de placement.\")\n",
+   "    fig.show()\n",
+   "    "
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 16,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "results_filename = './results_bench/results_bench_f2_01-31_16-13_0001'"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": 19,
+  "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
@@ -838,7 +856,796 @@
    {
     "data": {
      "text/html": [
-      ""
+      ""
     ],
     "text/plain": [
      ""
     ]
    },
    "metadata": {},
    "output_type": "display_data"
   },
   {
    "data": {
     "application/javascript": [
+     "/* matplotlib nbagg client-side widget JavaScript omitted: several hundred lines of generic figure/canvas/toolbar boilerplate with no notebook-specific content */"
     ]
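The benchmark notebooks dump their results as JSON with the keys shown above (`meta`, `matched`, `distances`, `positions`, `vt`), which is what display_bench.ipynb consumes. A minimal sketch of reading one of those files back and summarising it; the path below is the example file name printed by the benchmark run above, so adjust it to a results file that actually exists:

    import json
    import numpy as np

    def summarize_results(results_filename):
        with open(results_filename, 'r') as f:
            res = json.load(f)
        matched = np.array(res['matched'])
        distances = np.array(res['distances'])
        ok = matched == 1                     # fragments for which RANSAC returned a pose
        print("Fresque {}: {}/{} fragments matched ({:.1f}%)".format(
            res['meta']['fresque_id'], int(ok.sum()), len(matched), 100 * ok.mean()))
        if ok.any():
            print("Mean placement error of matched fragments: {:.2f} px".format(
                distances[ok].mean()))

    summarize_results('./results_bench/results_bench_f5_02-04_17-53_0003')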
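For reference, the geometric verification performed by `frag_match` above can be condensed as: estimate a partial 2D affine transform with RANSAC, reject the match when no 2x3 matrix is found or too few correspondences are inliers, and check that the transform is close to a pure rotation plus translation. A sketch under those assumptions — `frag_points`/`fresque_points` stand for the Nx2 correspondence lists that `frag_match` builds, and the scale check here is read directly from the matrix rather than re-sampled from point pairs as `test_frag` does:

    import numpy as np
    import cv2

    def estimate_pose(frag_points, fresque_points):
        """Return (tx, ty, angle_deg), or None when the RANSAC fit is rejected."""
        h, inliers = cv2.estimateAffinePartial2D(np.array(frag_points),
                                                 np.array(fresque_points))
        if h is None or np.size(h) != 6:      # no 2x3 affine matrix found
            return None
        if inliers.sum() <= 3:                # not enough inline correspondences
            return None
        # A partial affine is [s*cos(a) -s*sin(a) tx; s*sin(a) s*cos(a) ty],
        # so rotation and scale can be read off the first column.
        angle = np.degrees(np.arctan2(h[1, 0], h[0, 0]))
        scale = np.hypot(h[0, 0], h[1, 0])
        if abs(scale - 1) >= 0.3:             # Euclidean distances changed too much
            return None
        return h[0, 2], h[1, 2], angle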