Fresque-SETI/Apprentissage_MB.ipynb

621 lines
40 KiB
Text
Raw Normal View History

2021-02-11 09:00:27 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#Tous les codes sont basés sur l'environnement suivant\n",
"#python 3.7\n",
"#opencv 3.1.0\n",
"#pytorch 1.4.0\n",
"\n",
"import torch\n",
"from torch.autograd import Variable\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import random\n",
"import math\n",
"import pickle\n",
"import random\n",
"from PIL import Image\n",
2021-03-12 12:12:26 +01:00
"import sys"
2021-02-11 09:00:27 +01:00
]
},
{
"cell_type": "code",
2021-03-12 12:12:26 +01:00
"execution_count": 2,
2021-02-11 09:00:27 +01:00
"metadata": {},
"outputs": [],
"source": [
"# The functions in this cell are not used by the network itself; they are utility helpers.\n",
"\n",
"def tensor_imshow(im_tensor,cannel):\n",
"    \"\"\"Show an image tensor of shape (B, C, H, W) with matplotlib.\n",
"\n",
"    Single-channel tensors are shown as they are; otherwise only channel `cannel`\n",
"    is displayed. The tensor is squeezed first, so `cannel` indexes the channel\n",
"    axis only when B == 1 (NOTE(review): with B > 1 it would index the batch axis).\n",
"    \"\"\"\n",
"    _, n_channels, _, _ = im_tensor.shape\n",
"    pixels = im_tensor.squeeze().detach().numpy()\n",
"    if n_channels != 1:\n",
"        pixels = pixels[cannel, :]\n",
"    plt.imshow(pixels)\n",
"\n",
"# Training data helper: frag, vt = get_training_fragment(frag_size, image)\n",
"# frag is a square (frag_size x frag_size) patch taken at a random grid position of the image;\n",
"# vt is its ground truth, a Dirac map.\n",
"def get_training_fragment(frag_size,im):\n",
"    \"\"\"Cut a random frag_size x frag_size patch out of `im` and build its ground truth.\n",
"\n",
"    `im` is indexed as (height, width, channels). The patch lies on a grid aligned\n",
"    to multiples of frag_size. The ground truth is a zero map, on a grid whose cell\n",
"    size is frag_size/4, with a single 1 at the position of the patch centre.\n",
"\n",
"    Returns (patch, vt), vt being a float32 tensor of shape (1, 1, vt_h, vt_w).\n",
"    \"\"\"\n",
"    height, width, _ = im.shape\n",
"    row = random.randint(0, int(height/frag_size) - 1)\n",
"    col = random.randint(0, int(width/frag_size) - 1)\n",
"\n",
"    cell = frag_size/4\n",
"    vt_h = math.ceil((height+1)/cell)\n",
"    vt_w = math.ceil((width+1)/cell)\n",
"\n",
"    # Patch centre as a fraction of the image (mean of its start and end), mapped onto the vt grid\n",
"    centre_r = round((vt_h-1)*(row*frag_size/(height-1)+(row+1)*frag_size/(height-1))/2)\n",
"    centre_c = round((vt_w-1)*(col*frag_size/(width-1)+(col+1)*frag_size/(width-1))/2)\n",
"\n",
"    dirac = np.zeros([vt_h, vt_w])\n",
"    dirac[centre_r, centre_c] = 1\n",
"    vt = torch.from_numpy(np.float32(dirac).reshape(1, 1, vt_h, vt_w))\n",
"\n",
"    patch = im[row*frag_size:(row+1)*frag_size, col*frag_size:(col+1)*frag_size, :]\n",
"    return patch, vt\n",
"\n",
"# Convert an image array (Height, Width, Channel) into a float32 torch tensor.\n",
"# By default the result is (Channel, Height, Width); pass batch=True to get the\n",
"# (Batch=1, Channel, Height, Width) layout used by the network.\n",
"def img2tensor(im, batch=False):\n",
"    \"\"\"Convert an (H, W, C) image to a float32 tensor.\n",
"\n",
"    Returns a (C, H, W) tensor, which is what this function has always returned\n",
"    (the batched tensor it used to compute was discarded). With batch=True a\n",
"    leading batch axis is added: (1, C, H, W).\n",
"    \"\"\"\n",
"    im = np.array(im, dtype='float32')\n",
"    tensor_cv = torch.from_numpy(np.transpose(im, (2, 0, 1)))\n",
"    return tensor_cv.unsqueeze(0) if batch else tensor_cv\n",
"\n",
"# Locate the maximum of a correlation map, as fractions of its height and width.\n",
"# x, y = show_coordonnee(correlation_map)\n",
"def show_coordonnee(position_pred):\n",
"    \"\"\"Return the normalised (row, col) position of the maximum of a correlation map.\n",
"\n",
"    The map is squeezed to 2-D. If the maximum occurs several times, the mean row\n",
"    and mean column of all occurrences are used. The row is divided by the map\n",
"    height and the column by the map width.\n",
"    \"\"\"\n",
"    corr = position_pred.squeeze().detach().numpy()\n",
"    n_rows, n_cols = corr.shape\n",
"    rows, cols = np.nonzero(corr == corr.max())\n",
"    return rows.mean()/n_rows, cols.mean()/n_cols\n",
"\n",
"# Filter patches by the number of pure-black pixels they contain.\n",
"# Returns True when non-black pixels make up at least a `seuillage` fraction of the patch, else False.\n",
"def test_fragment32_32(frag,seuillage):\n",
"    \"\"\"Return True if the share of black pixels in `frag` is <= 1 - seuillage.\n",
"\n",
"    A pixel is black when its first three channels are all 0. The channels are\n",
"    compared individually instead of being summed: summing uint8 arrays wraps\n",
"    around at 256, so a pixel such as (128, 128, 0) was counted as black.\n",
"    An empty patch has no usable pixels and returns False.\n",
"    \"\"\"\n",
"    black = np.all(np.asarray(frag)[:, :, :3] == 0, axis=2)\n",
"    if black.size == 0:\n",
"        return False\n",
"    return np.count_nonzero(black)/black.size <= (1-seuillage)\n",
" \n",
"# These two functions save a network to a file, or load a network stored in a file\n",
"# (the whole object is pickled).\n",
"def save_net(file_path,net):\n",
"    \"\"\"Pickle `net` into `file_path`; the file is closed even if pickling fails.\"\"\"\n",
"    with open(file_path, 'wb') as pkl_file:\n",
"        pickle.dump(net, pkl_file)\n",
"\n",
"def load_net(file_path):\n",
"    \"\"\"Unpickle and return the object stored in `file_path`.\n",
"\n",
"    Security: pickle.load can execute arbitrary code; only load files you\n",
"    produced yourself or otherwise trust.\n",
"    \"\"\"\n",
"    with open(file_path, 'rb') as pkl_file:\n",
"        return pickle.load(pkl_file)"
]
},
{
"cell_type": "code",
2021-03-12 12:12:26 +01:00
"execution_count": 3,
2021-02-11 09:00:27 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Build DeepMatch-style weights usable as the initial value of conv1 (optional).\n",
"def ini():\n",
"    \"\"\"Return an (8, 3, 3, 3) trainable Parameter of oriented 3x3 gradient kernels.\n",
"\n",
"    Output channel k holds the k-th Sobel-like kernel below (its orientation turns\n",
"    by 45 degrees from one kernel to the next), copied to all 3 input channels.\n",
"    \"\"\"\n",
"    compass = [\n",
"        [[1, 2, 1], [0, 0, 0], [-1, -2, -1]],\n",
"        [[2, 1, 0], [1, 0, -1], [0, -1, -2]],\n",
"        [[1, 0, -1], [2, 0, -2], [1, 0, -1]],\n",
"        [[0, -1, -2], [1, 0, -1], [2, 1, 0]],\n",
"        [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],\n",
"        [[-2, -1, 0], [-1, 0, 1], [0, 1, 2]],\n",
"        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],\n",
"        [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]],\n",
"    ]\n",
"    kernel = torch.zeros([8, 3, 3, 3])\n",
"    for out_ch, pattern in enumerate(compass):\n",
"        # the 3x3 pattern is broadcast over the 3 input channels\n",
"        kernel[out_ch, :] = torch.tensor(pattern, dtype=torch.float32)\n",
"    return torch.nn.Parameter(kernel, requires_grad=True)\n",
"\n",
"# Initial (fixed) weights of the 1x1 'add' convolution.\n",
"# n, m: the input patch holds n x m sub-patches, one correlation map per channel\n",
"# (channel i*m + j for sub-patch row i, column j).\n",
"# E.g. for a 16x16 input: n = m = 4 for the 4x4 patches of the first layer,\n",
"# n = m = 2 for the 8x8 patches of the second layer.\n",
"def kernel_add_ini(n,m):\n",
"    \"\"\"Return a frozen (n//2 * m//2, n*m, 1, 1) Parameter summing 2x2 groups of channels.\n",
"\n",
"    Output channel i*(m//2) + j is the sum of the input channels of sub-patches\n",
"    (2i, 2j), (2i, 2j+1), (2i+1, 2j) and (2i+1, 2j+1). With an odd n or m the\n",
"    last row/column of sub-patches is not used.\n",
"\n",
"    Raises ValueError when n < 2 or m < 2 (there is no 2x2 group to sum).\n",
"    \"\"\"\n",
"    n_half, m_half = int(n/2), int(m/2)\n",
"    if n_half < 1 or m_half < 1:\n",
"        raise ValueError('kernel_add_ini needs n >= 2 and m >= 2, got n={}, m={}'.format(n, m))\n",
"    input_canal = int(n*m)\n",
"    # Fill the whole weight once instead of concatenating one row per output channel\n",
"    add = np.zeros([n_half*m_half, input_canal, 1, 1], dtype='float32')\n",
"    for i in range(n_half):\n",
"        for j in range(m_half):\n",
"            out = i*m_half + j\n",
"            for di in (0, 1):\n",
"                for dj in (0, 1):\n",
"                    add[out, (i*2+di)*m + j*2+dj, 0, 0] = 1\n",
"    return torch.nn.Parameter(torch.from_numpy(add), requires_grad=False)\n",
"\n",
"# Initial (fixed) weights of the 3x3 'shift' convolution; shift + add together\n",
"# perform the aggregation step. n, m: number of sub-patch rows / columns in the\n",
"# input patch (channel i*m + j is sub-patch (i, j)).\n",
"# See the author's internship report for the detailed derivation.\n",
"def kernel_shift_ini(n,m):\n",
"    \"\"\"Return a frozen (n*m, n*m, 3, 3) Parameter moving each channel by one pixel diagonally.\n",
"\n",
"    The weight is diagonal in the channels: channel i*m + j gets a 3x3 kernel\n",
"    holding a single 1 in the corner picked by the parity of (i, j):\n",
"    (even, even) -> top-left, (even, odd) -> top-right,\n",
"    (odd, even) -> bottom-left, (odd, odd) -> bottom-right.\n",
"    \"\"\"\n",
"    size = int(n*m)\n",
"    corner_of_parity = {(0, 0): (0, 0), (0, 1): (0, 2), (1, 0): (2, 0), (1, 1): (2, 2)}\n",
"    weights = torch.zeros([size, size, 3, 3])\n",
"    for i in range(n):\n",
"        for j in range(m):\n",
"            row, col = corner_of_parity[(i % 2, j % 2)]\n",
"            weights[i*m + j, i*m + j, row, col] = 1\n",
"    return torch.nn.Parameter(weights, requires_grad=False)\n",
"\n",
"# Extract the psize x psize sub-patch at row n, column m of a 4-D tensor (e.g. B, C, H, W).\n",
"# It is used as a convolution kernel to compute the correlation maps.\n",
"def get_patch(fragment,psize,n,m):\n",
"    \"\"\"Return the (n, m)-th psize x psize tile over the last two axes of `fragment`.\"\"\"\n",
"    rows = slice(n*psize, (n+1)*psize)\n",
"    cols = slice(m*psize, (m+1)*psize)\n",
"    return fragment[:, :, rows, cols]\n",
"\n",
"###################################################################################################################\n",
"class Net(nn.Module):\n",
"    \"\"\"Locate a fragment inside a fresque by dense descriptor correlation.\n",
"\n",
"    conv1 (+ ReLU) computes descriptors for both images. The fragment is cut into\n",
"    psize x psize sub-patches, each sub-patch is correlated with the fresque\n",
"    descriptors, and the resulting maps are merged 2x2 at a time by fixed\n",
"    max-pool / shift / add layers (at most three levels).\n",
"    \"\"\"\n",
"    def __init__(self,frag_size,psize):\n",
"        super(Net, self).__init__()\n",
"\n",
"        h_fr=frag_size\n",
"        w_fr=frag_size\n",
"\n",
"        n=int(h_fr/psize) # n*m sub-patches in the input patch\n",
"        m=int(w_fr/psize)\n",
"\n",
"        # Stored so that forward() cuts the fragment with the same psize as the layers below\n",
"        self.psize=psize\n",
"\n",
"        self.conv1 = nn.Conv2d(3,8,kernel_size=3,stride=1,padding=1)\n",
"        # To initialise conv1 with the DeepMatch-style weights, uncomment the next line\n",
"        # self.conv1.weight=ini()\n",
"        self.Relu = nn.ReLU(inplace=True)\n",
"        self.maxpooling=nn.MaxPool2d(3,stride=2, padding=1)\n",
"\n",
"        # Aggregation layers: only their weights are replaced by fixed kernels.\n",
"        # NOTE(review): the default conv biases are still present and trainable.\n",
"        self.shift1=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
"        self.shift1.weight=kernel_shift_ini(n,m)\n",
"        self.add1 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
"        self.add1.weight=kernel_add_ini(n,m)\n",
"\n",
"        n=int(n/2)\n",
"        m=int(m/2)\n",
"        if n>=2 and m>=2: # once n = m = 1 no further aggregation layer is needed\n",
"            self.shift2=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
"            self.shift2.weight=kernel_shift_ini(n,m)\n",
"            self.add2 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
"            self.add2.weight=kernel_add_ini(n,m)\n",
"\n",
"        n=int(n/2)\n",
"        m=int(m/2)\n",
"        if n>=2 and m>=2:\n",
"            self.shift3=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
"            self.shift3.weight=kernel_shift_ini(n,m)\n",
"            self.add3 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
"            self.add3.weight=kernel_add_ini(n,m)\n",
"\n",
"    def get_descripteur(self,img,using_cuda):\n",
"        \"\"\"Return L2-normalised dense descriptors of `img`, shape (B, 9, H, W).\n",
"\n",
"        conv1 + ReLU give 8 channels; a constant 0.5 channel is appended so no\n",
"        descriptor has a zero norm before normalisation. The constant channel is\n",
"        created on the descriptors' own device; `using_cuda` is kept for\n",
"        compatibility with existing callers.\n",
"        \"\"\"\n",
"        descripteur_img=self.Relu(self.conv1(img))\n",
"        b,c,h,w=descripteur_img.shape\n",
"        couche_constante = torch.full((b, 1, h, w), 0.5, dtype=descripteur_img.dtype, device=descripteur_img.device)\n",
"        descripteur_img = torch.cat((descripteur_img,couche_constante),1)\n",
"        # Unit L2 norm along the channel axis (F.normalize defaults: p=2, dim=1)\n",
"        return F.normalize(descripteur_img)\n",
"\n",
"    def forward(self,img,frag,using_cuda):\n",
"        \"\"\"Return the correlation map(s) of `frag` inside `img`.\n",
"\n",
"        img: fresque batch (B, 3, H, W); frag: fragment batch (B, 3, h, w).\n",
"        Each sample's output is divided by its own maximum.\n",
"        \"\"\"\n",
"        # Nets pickled before psize was stored always used 4\n",
"        psize=getattr(self,'psize',4)\n",
"        descripteur_input2=self.get_descripteur(frag,using_cuda)\n",
"        descripteur_input1=self.get_descripteur(img,using_cuda)\n",
"\n",
"        b,c,h,w=frag.shape\n",
"        n=int(h/psize)\n",
"        m=int(w/psize)\n",
"\n",
"        db,dc,dh,dw = descripteur_input1.shape\n",
"        # Fold the batch into the channel axis: with groups=db each sample's fresque\n",
"        # is correlated only with the sub-patches of its own fragment.\n",
"        fresque_grouped=descripteur_input1.view(1,db*dc,dh,dw)\n",
"\n",
"        # One correlation map per sub-patch; channel i*m + j holds sub-patch (i, j)\n",
"        maps=[]\n",
"        for i in range(n):\n",
"            for j in range(m):\n",
"                a=F.conv2d(fresque_grouped,get_patch(descripteur_input2,psize,i,j),padding=psize//2,groups=db)\n",
"                maps.append(a.view(db,1,a.size(2),a.size(3)))\n",
"        map_corre=torch.cat(maps,1)\n",
"\n",
"        # Aggregation: max-pool, shift and add, repeated level by level\n",
"        map_corre=self.maxpooling(map_corre)\n",
"        map_corre=self.shift1(map_corre)\n",
"        map_corre=self.add1(map_corre)\n",
"\n",
"        n=int(n/2)\n",
"        m=int(m/2)\n",
"        if n>=2 and m>=2:\n",
"            map_corre=self.maxpooling(map_corre)\n",
"            map_corre=self.shift2(map_corre)\n",
"            map_corre=self.add2(map_corre)\n",
"\n",
"        n=int(n/2)\n",
"        m=int(m/2)\n",
"        if n>=2 and m>=2:\n",
"            map_corre=self.maxpooling(map_corre)\n",
"            map_corre=self.shift3(map_corre)\n",
"            map_corre=self.add3(map_corre)\n",
"\n",
"        # Divide each sample by its own maximum: dividing by the max of the whole\n",
"        # batch made one sample's output depend on the other samples of the batch.\n",
"        peak=map_corre.reshape(map_corre.size(0),-1).max(dim=1)[0]\n",
"        return map_corre/peak.view(-1,1,1,1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset and Dataloader"
]
},
{
"cell_type": "code",
2021-03-12 12:12:26 +01:00
"execution_count": 4,
2021-02-11 09:00:27 +01:00
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"from PIL import Image\n",
"from torchvision import transforms\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from time import time\n",
"\n",
"\n",
"class FragmentDataset(Dataset):\n",
"    def __init__(\n",
"        self,\n",
"        fragments_path,\n",
"        train,\n",
"        frags_transform=transforms.ToTensor(),\n",
"        fresques_transform=None,\n",
"        vts_transform=None,\n",
"    ):\n",
"        \"\"\"\n",
"        Dataset of (fragment, fresque, vt) samples.\n",
"\n",
"        Layout read by this class:\n",
"            <fragments_path>/fresqueX.ppm                full fresques\n",
"            <fragments_path>/fragments/fresqueX/...      fragment images\n",
"            <fragments_path>/fragments/fresqueX/vt/...   ground-truth text files\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        fragments_path: str\n",
"            Path to the root folder holding the fresques and the fragments folder.\n",
"        train: boolean\n",
"            True for the train set (files containing _dev_), False for the test set (_bench_).\n",
"        frags_transform: torchvision.transform\n",
"            Transform applied to every fragment image. Default: ToTensor()\n",
"        fresques_transform: torchvision.transform\n",
"            Transform applied to every fresque image. frags_transform if None.\n",
"        vts_transform: callable or None\n",
"            Optional transform applied to every vt tensor. Default: None (no transform).\n",
"        \"\"\"\n",
"        self.base_path = fragments_path\n",
"        self.frags_transform = frags_transform\n",
"        self.fresques_transform = fresques_transform if fresques_transform else frags_transform\n",
"        self.fragments_list = []\n",
"        self.vts_transform = vts_transform\n",
"\n",
"        # To separate between train (dev) and test (bench) fragments\n",
"        self.match_expr = \"_dev_\" if train else \"_bench_\"\n",
"\n",
"        fragments_path = os.path.join(self.base_path, \"fragments\")\n",
"        for fresque_dir in os.listdir(fragments_path):\n",
"            current_path = os.path.join(fragments_path, fresque_dir)\n",
"\n",
"            # Only fresque sub-folders. Test the folder name itself: testing the full\n",
"            # path matched every entry whenever the root path contained \"fresque\".\n",
"            if \"fresque\" not in fresque_dir or not os.path.isdir(current_path):\n",
"                continue\n",
"\n",
"            # Path to the full fresque (ie: ..path/fresque0.ppm)\n",
"            full_fresque_path = os.path.join(self.base_path, fresque_dir + \".ppm\")\n",
"\n",
"            # Fragments (ie: ..path/fresque0/frag_bench_000.ppm) and their vts\n",
"            # (ie: ..path/fresque0/vt/frag_bench_000...). Both lists are restricted to\n",
"            # this set (Train | Test) before pairing, so that an extra file in one\n",
"            # folder cannot shift the sorted pairing.\n",
"            vts_path = os.path.join(current_path, \"vt\")\n",
"            frag_names = sorted(f for f in os.listdir(current_path) if re.search(self.match_expr, f))\n",
"            vt_names = sorted(f for f in os.listdir(vts_path) if re.search(self.match_expr, f))\n",
"\n",
"            self.fragments_list.extend(\n",
"                (os.path.join(current_path, frag_name), full_fresque_path, os.path.join(vts_path, vt_name))\n",
"                for frag_name, vt_name in zip(frag_names, vt_names)\n",
"            )\n",
"\n",
"    def __len__(self):\n",
"        return len(self.fragments_list)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return (fragment, fresque, vt) for sample idx, each passed through its transform.\"\"\"\n",
"        frag_path, fresque_path, vt_path = self.fragments_list[idx]\n",
"        fragment = Image.open(frag_path)\n",
"        fresque = Image.open(fresque_path)\n",
"\n",
"        # vt file: one integer per line. Lines 0-1 give the extent of the map and\n",
"        # lines 2-3 the position of the Dirac (NOTE(review): presumably fresque\n",
"        # height/width and fragment row/col in pixels - confirm with the data generator).\n",
"        with open(vt_path,'r') as f:\n",
"            data_vt = [int(d.rstrip('\\r\\n')) for d in f.readlines()]\n",
"\n",
"        # Construct vt: a single 1 on a grid downsampled by div\n",
"        # NOTE(review): div = 8 looks tied to frag_size / 4 for 32-pixel fragments - confirm.\n",
"        div = 8\n",
"        vt = np.zeros((int(data_vt[0]/div)+1,int(data_vt[1]/div)+1), dtype=np.float32)\n",
"        vt[int(data_vt[2]/div),int(data_vt[3]/div)] = 1\n",
"        # Add a leading channel axis -> (1, H, W). The former reshape used\n",
"        # vt.shape[0] for both sizes and failed on every non-square map.\n",
"        vt = torch.from_numpy(vt[np.newaxis])\n",
"        if self.vts_transform is not None:\n",
"            vt = self.vts_transform(vt)\n",
"\n",
"        return self.frags_transform(fragment), self.fresques_transform(fresque), vt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage: build the train and test datasets and DataLoaders"
]
},
{
"cell_type": "code",
2021-03-12 12:12:26 +01:00
"execution_count": 5,
2021-02-11 09:00:27 +01:00
"metadata": {},
"outputs": [],
"source": [
"# Fresques are resized to 1000x1000 before being converted to tensors\n",
"fresques_tnsf = transforms.Compose([transforms.Resize((1000, 1000)), transforms.ToTensor()])\n",
"\n",
"DATA_ROOT = \"training_data_32\"\n",
"train, test = (\n",
"    FragmentDataset(fragments_path=DATA_ROOT, train=is_train, fresques_transform=fresques_tnsf)\n",
"    for is_train in (True, False)\n",
")\n",
"\n",
"bs = 2\n",
"LOADER_OPTIONS = dict(batch_size=bs, num_workers=6, pin_memory=True, shuffle=True)\n",
"train_loader = DataLoader(train, **LOADER_OPTIONS)\n",
"test_loader = DataLoader(test, **LOADER_OPTIONS)"
2021-02-11 09:00:27 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
2021-03-12 12:12:26 +01:00
"execution_count": 6,
2021-02-11 09:00:27 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-03-12 12:12:26 +01:00
"Temps par batch: 0.291\n",
"Temps par batch: 6.76\n",
"Temps par batch: 5.59\n",
"Temps par batch: 5.6\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-7e6697633d7e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mdel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfragments\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;32mdel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfresques\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mvts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m \u001b[0mcost\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mcost\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
2021-02-11 09:00:27 +01:00
]
}
],
"source": [
2021-03-12 12:12:26 +01:00
"frag_size = 32\n",
"psize = 4\n",
"num_epochs = 6\n",
"\n",
"net = Net(frag_size, psize).cuda()\n",
"optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
"loss_func = torch.nn.MSELoss()\n",
"\n",
"# One MSE value per training batch, plotted below\n",
"loss_value = []\n",
"\n",
"# empty_cache must be *called*: the bare attribute reference used before was a no-op\n",
"torch.cuda.empty_cache()\n",
"time_old = time()\n",
"\n",
"for epoch in range(num_epochs):\n",
"    for fragments, fresques, vts in train_loader:\n",
"        print(\"Temps par batch: {:.3}\".format(time()-time_old))\n",
"        time_old = time()\n",
"\n",
"        fragments = fragments.cuda()\n",
"        fresques = fresques.cuda()\n",
"\n",
"        optimizer.zero_grad()\n",
"        preds = net(fresques, fragments, True)\n",
"        # Release the input tensors as soon as they are no longer needed\n",
"        del fragments, fresques\n",
"        vts = vts.cuda()\n",
"        cost = loss_func(preds, vts)\n",
"        cost.backward()\n",
"        del vts\n",
"        optimizer.step()\n",
"\n",
"        loss_value.append(cost.item())\n",
"    # Release cached GPU blocks once per epoch (doing it every batch only slows training down)\n",
"    torch.cuda.empty_cache()\n",
"    print('Done with epoch ', epoch)\n",
"\n",
"# Save the trained network (whole object, pickled) to file_path\n",
"file_path=\"./net_trainned_from-random_full-dataset-small_02-10_07-30_0107\"\n",
"save_net(file_path,net)"
2021-02-11 09:00:27 +01:00
]
},
{
"cell_type": "code",
2021-03-12 12:12:26 +01:00
"execution_count": 7,
2021-02-11 09:00:27 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13500"
]
},
2021-03-12 12:12:26 +01:00
"execution_count": 7,
2021-02-11 09:00:27 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of recorded losses (one per training batch)\n",
"len(loss_value)"
]
},
{
"cell_type": "code",
2021-03-12 12:12:26 +01:00
"execution_count": 8,
2021-02-11 09:00:27 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-03-12 12:12:26 +01:00
"[<matplotlib.lines.Line2D at 0x7f411022f2e0>]"
2021-02-11 09:00:27 +01:00
]
},
2021-03-12 12:12:26 +01:00
"execution_count": 8,
2021-02-11 09:00:27 +01:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2021-03-12 12:12:26 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAp4ElEQVR4nO3dd3xUVd7H8c9JQkLvkY4BaYKKCiKIFRtlld21rGVdy6rrrnX1eXaxPoq66iq6q+La3V3XBlZWULpYKaFJhxAQEkpCCwkkpJ3nj7kZZpJJMglTcme+79eLF/eee3Lnl5vkN3fuufd3jLUWERGJLwnRDkBERCJPyV9EJA4p+YuIxCElfxGROKTkLyISh5Ki9cLt27e3aWlp0Xp5ERFXWrx48S5rbeqR7idqyT8tLY309PRovbyIiCsZY34KxX502UdEJA4p+YuIxCElfxGROKTkLyISh5T8RUTikJK/iEgcUvIXEYlDrkv+ew8UM/XH7dEOQ0TE1VyX/H//zmJufXcJO/KKoh2KiIhruS75b91TCEBJWXmUIxERcS/XJf/sfZ7kn7nrQJQjERFxL9cl/woZOQXRDkFExLVcm/xFRKT+XJv8NfG8iEj9uTb5i4hI/Sn5i4jEIdcmf131ERGpP9cm/ybJidEOQUTEtVyb/HXiLyJSf+5N/rruIyJSby5O/tGOQETEvVyb/F+etzHaIYiIuJZrk/92VfUUEak31yZ/ERGpPyV/EZE4pOQvIhKHlPxFROKQ65L/gM4tox2CiIjruS75d23TJNohiIi4nuuSv+/DXbn5h6IXiIiIiwWV/I0xI40x64wxGcaYcQG2dzfGzDXGLDXG/GiMGR36UD2SEo13Oa+wOFwvIyIS02pN/saYRGAiMAroD1xpjOlfqdsDwCRr7UnAFcBLoQ60Qlm56jqIiBypYM78hwAZ1tpMa20x8D4wtlIfC1SMxLYCtoUuRH8DOrcK165FROJGMMm/C7DVZz3LafP1MPBrY0wWMA24PdCOjDE3G2PSjTHpubm59QgXkpMOh6zibiIi9ROqAd8rgX9aa7sCo4G3jTFV9m2tfdVaO9haOzg1NbVeL5SS5LoxahGRBieYTJoNdPNZ7+q0+fotMAnAWvsD0BhoH4oAK7vq1O7e5RfmZKiuv4hIPQST/BcBvY0xPYwxyXgGdKdU6rMFOBfAGHMsnuRfv+s6tUhJOjx945Tl21i3Mz8cLyMiEtNqTf7W2lLgNmA6sAbPXT2rjDHjjTEXO93uAW4yxiwH3gOusxE6JS8vj8SriIjElqRgOllrp+EZyPVte8hneTUwPLShBWf089/wxZ1ncGwnlX0QEQlWTIyevjBnQ7RDEBFxlZhI/hrzFRGpm5hI/iIiUjcxkfy37j0Y7RBERFwlJpL/yuz90Q5BRMRVYiL5i4hI3Sj5i4jEISV/EZE4FDPJ/2BxabRDEBFxjZhJ/jf9Oz3aIYiIuEbMJP/vMnZHOwQREdeImeQvIiLBU/IXEYlDSv4iInFIyV9EJA65MvkP6dE22iGIiLiaK5P/KWltoh2CiIiruTL5JxgTsL3nvVO5e9KyyAYjIuJCrkz+pprkX27h4yXZEY5GRMR9XJn8E6tJ/iIiEhxXJv/GjVwZtohIg+HKLHrtaWnRDkFExNVcmfwbN0qMdggiIq7myuQvIiJHJiaT/5rtmtNXRKQmrk3+J3RtVe22UX//hi9Xbo9gNCIi7uLa5P/kL0+ocfs3G3ZFKBIREfdxbfKvjXX+X71tP8OfnMPeA8VRjUdEpCGJ2eRfYeJXGWTvK+SbDH0SEBGpELPJ31rPub+ptC4iIi5O/p1bN65x+3sLt1JcWl5tHSARkXiWFO0A6qt10+Ra+/z6jQUs3LQHAJ34i4gc5toz/2BUJH4Ai7K/iEgFVyf/RfefF+0QRERcydXJP7VFStB9rYVvN+yiqKQsjBGJiLiDq5N/Xazbmc+v31jAw1NWRTsUEZGoCy
r5G2NGGmPWGWMyjDHjqulzuTFmtTFmlTHm3dCGWb2UpODev/IOlgCQkVMQznBERFyh1sxpjEkEJgKjgP7AlcaY/pX69AbuBYZbawcAd4U+1MAeHXtcUP02OElfw74iIsGd+Q8BMqy1mdbaYuB9YGylPjcBE621ewGstTmhDbN6l5/SLah+i3/aG+ZIRETcI5jk3wXY6rOe5bT56gP0McZ8Z4yZb4wZGWhHxpibjTHpxpj03Nzc+kV8hPSkr4hI6AZ8k4DewNnAlcBrxpjWlTtZa1+11g621g5OTU0N0UvXzZIt+6LyuiIiDUkwyT8b8L220tVp85UFTLHWllhrNwHr8bwZiIhIAxRM8l8E9DbG9DDGJANXAFMq9fkUz1k/xpj2eC4DZYYuzPDJLyqhtKw82mGIiERUrcnfWlsK3AZMB9YAk6y1q4wx440xFzvdpgO7jTGrgbnA/1prd4cr6FApL7cc//AM7vxgWbRDERGJqKAKu1lrpwHTKrU95LNsgbudf67x7x82AzD1x+1MvCq6sYiIRFLcPOHra+5az52o2fsKoxyJiEh0xETyv2xQ1zr1v+O9pRwsLuW1bzb5tX+/cRdp46ayMjsvlOGJiDQ4MZH86yr/UCnLtu6r0j5rtecTwfzMBj9cISJyRGIi+dfrsS096yUicSwmkn95PZ7aPVDsX9o5bdxUDhaXhiokEZEGLSaSf33c9O/0Km3vL9oaoKeISOyJjeSvSzgiInUSG8lfRETqJCaSf8dWjaMdgoiIq8RE8r/rvD7RDkFExFViIvknBzmVo4iIeChrhtCGnfms35kf7TBERGql5B/Ayuw874TvdXH+c19zwXNfhyEiEZHQUvIP4NNl2xg4fgaZuQVV3gSKS8tZoPIPIuJySv41GDFhHqc+Mcuv7bZ3l/CrV+ezZvv+KEUlInLklPxrUVRSzn2frPCuz1i9E4D0zXuiFZKIyBGLueR/x7mhnzr43QVbqrQ9+NmqkL+OiEikxFzyv/60tGiHICLS4AU1jaMbzLnnLA4Wl2FMtCMREWn4Yib590xtDkBeYd1v0RQRiTcxd9kn0mf+2/YVkjZuKtNWbI/sC4uIHIGYS/4pYSr18I+vNmIDTBpz2pNzAHjos5VheV0RkXCIweSfyIL7zqVRYmg/Ajz15Vpe+mpjSPcpIhItMZf8ATq0bMznt58R8v0+PX1dtds01iAibhKTyR+gb8cWfPT70yL2eiVlmk5MRNwjZpM/wKCj20Q7BBGRBimmk3+4DXl8Vo3bN+YW8PzsDRGKRkQkeEr+RyAn/1DA9u15hQBc9dp8np25nr0HiiMZlohIrZT8w2DKsm0AHCotB0CjASLS0Cj5h8Hm3QfYuucgqjQhIg1VzJR3aEjeW7iV9xZupU3TRgABHw4TEYkmnfmHkVGVORFpoJT8I8D3vN9aS1m5PgmISHQp+YfRHucun9/+cxGfLcsG4I1vN3HMfdN0B5CIRFVcJf/xYwdE5XWXZ+Vx5/vLAJicngXAjv1FUYlFRATiJPkP7NaaaXecwW+GpUU7lColp/OLSig4VMqugsDPDIiIhENQd/sYY0YCfwcSgdettU9W0+8S4EPgFGttesiiPAIz/ngmnVo1pkXjRtEOhc+WZbN2Rz4A1sKh0jKOf3iGd/v0u86kb8cWAKzZvp9+HVto0FhEwqLWM39jTCIwERgF9AeuNMb0D9CvBXAnsCDUQR6JPh1aNIjED3gv/QBYLA9P8Z8EfmNuAQAzV+9k1N+/4ZOl2ZEMT0TiSDBn/kOADGttJoAx5n1gLLC6Ur9HgaeA/w1phDFqzPPfVmmreBwgI8fzJrBuZ34kQxKROBLMNf8uwFaf9SynzcsYczLQzVo7taYdGWNuNsakG2PSc3Nz6xxsrLPOTaFWBSFEJMyOeMDXGJMAPAvcU1tfa+2r1trB1trBqampR/rSMafy7f/rd+jMX0TCI5jknw1081nv6rRVaAEcB3xljNkMDAWmGGMGhy
rIUGrfPIXWTRvGGEBls9fsZGNugffyz9x1+nQkIuERTPJfBPQ2xvQwxiQDVwBTKjZaa/Oste2ttWnW2jRgPnBxQ7n
2021-02-11 09:00:27 +01:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Training curve: one MSE value per batch; the trailing ';' hides the Line2D repr\n",
"fig, ax = plt.subplots()\n",
"ax.plot(loss_value)\n",
"ax.set(title=\"Training loss per batch (MSE)\", xlabel=\"Batch\", ylabel=\"Loss\");"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'net' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-47d0cd6ed22d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mfile_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"./net_trainned6000_MB_102\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msave_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'net' is not defined"
]
}
],
"source": [
"# Save an extra copy of the trained network under another name.\n",
"# Requires the training cell above to have run in this kernel (`net` must exist).\n",
"file_path=\"./net_trainned6000_MB_102\"\n",
"save_net(file_path,net)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}