620 lines
40 KiB
Text
620 lines
40 KiB
Text
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#Tous les codes sont basés sur l'environnement suivant\n",
|
||
"#python 3.7\n",
|
||
"#opencv 3.1.0\n",
|
||
"#pytorch 1.4.0\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"from torch.autograd import Variable\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"import cv2\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"import random\n",
|
||
"import math\n",
|
||
"import pickle\n",
|
||
"import random\n",
|
||
"from PIL import Image\n",
|
||
"import sys"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#Les fonctions dans ce bloc ne sont pas utilisées par le réseau, mais certaines fonctions d'outils\n",
|
||
"\n",
|
||
"def tensor_imshow(im_tensor,cannel):\n",
|
||
" b,c,h,w=im_tensor.shape\n",
|
||
" if c==1:\n",
|
||
" plt.imshow(im_tensor.squeeze().detach().numpy())\n",
|
||
" else:\n",
|
||
" plt.imshow(im_tensor.squeeze().detach().numpy()[cannel,:])\n",
|
||
"\n",
|
||
"# Obtenez des données d'entraînement\n",
|
||
"# frag,vt=get_training_fragment(frag_size,image)\n",
|
||
"# frag est un patch carrée de taille (frag_size*frag_size) a partir du image(Son emplacement est aléatoire)\n",
|
||
"# vt est la vérité terrain de la forme Dirac.\n",
|
||
"def get_training_fragment(frag_size,im):\n",
|
||
" h,w,c=im.shape\n",
|
||
" n=random.randint(0,int(h/frag_size)-1)\n",
|
||
" m=random.randint(0,int(w/frag_size)-1) \n",
|
||
" shape=frag_size/4\n",
|
||
" vt_h=math.ceil((h+1)/shape)\n",
|
||
" vt_w=math.ceil((w+1)/shape)\n",
|
||
" vt=np.zeros([vt_h,vt_w])\n",
|
||
" vt_h_po=round((vt_h-1)*(n*frag_size/(h-1)+(n+1)*frag_size/(h-1))/2)\n",
|
||
" vt_w_po=round((vt_w-1)*(m*frag_size/(w-1)+(m+1)*frag_size/(w-1))/2)\n",
|
||
" vt[vt_h_po,vt_w_po]=1\n",
|
||
" vt = np.float32(vt)\n",
|
||
" vt=torch.from_numpy(vt.reshape(1,1,vt_h,vt_w))\n",
|
||
" \n",
|
||
" return im[n*frag_size:(n+1)*frag_size,m*frag_size:(m+1)*frag_size,:],vt\n",
|
||
"\n",
|
||
"# Cette fonction convertit l'image en variable de type Tensor.\n",
|
||
"# Toutes les données de calcul du réseau sont de type Tensor\n",
|
||
"# Img.shape=[Height,Width,Channel]\n",
|
||
"# Tensor.shape=[Batch,Channel,Height,Width]\n",
|
||
"def img2tensor(im):\n",
|
||
" im=np.array(im,dtype=\"float32\")\n",
|
||
" tensor_cv = torch.from_numpy(np.transpose(im, (2, 0, 1)))\n",
|
||
" im_tensor=tensor_cv.unsqueeze(0)\n",
|
||
" return tensor_cv\n",
|
||
"\n",
|
||
"# Trouvez les coordonnées de la valeur maximale dans une carte de corrélation\n",
|
||
"# x,y=show_coordonnee(carte de corrélation)\n",
|
||
"def show_coordonnee(position_pred):\n",
|
||
" map_corre=position_pred.squeeze().detach().numpy()\n",
|
||
" h,w=map_corre.shape\n",
|
||
" max_value=map_corre.max()\n",
|
||
" coordonnee=np.where(map_corre==max_value)\n",
|
||
" return coordonnee[0].mean()/h,coordonnee[1].mean()/w\n",
|
||
"\n",
|
||
"# Filtrer les patchs en fonction du nombre de pixels noirs dans le patch\n",
|
||
"# Si seuls les pixels non noirs sont plus grands qu'une certaine proportion(seuillage), revenez à True, sinon False\n",
|
||
"def test_fragment32_32(frag,seuillage):\n",
|
||
" a=frag[:,:,0]+frag[:,:,1]+frag[:,:,2]\n",
|
||
" mask = (a == 0)\n",
|
||
" arr_new = a[mask]\n",
|
||
" if arr_new.size/a.size<=(1-seuillage):\n",
|
||
" return True\n",
|
||
" else:\n",
|
||
" return False\n",
|
||
" \n",
|
||
"# Ces deux fonctions permettent de sauvegarder le réseau dans un fichier\n",
|
||
"# ou de load le réseau stocké à partir d'un fichier\n",
|
||
"def save_net(file_path,net):\n",
|
||
" pkl_file = open(file_path, 'wb')\n",
|
||
" pickle.dump(net,pkl_file)\n",
|
||
" pkl_file.close()\n",
|
||
"def load_net(file_path): \n",
|
||
" pkl_file = open(file_path, 'rb')\n",
|
||
" net= pickle.load(pkl_file)\n",
|
||
" pkl_file.close()\n",
|
||
" return net"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Créer un poids de type DeepMatch comme valeur initiale de Conv1 (non obligatoire)\n",
|
||
"def ini():\n",
|
||
" kernel=torch.zeros([8,3,3,3])\n",
|
||
" array_0=np.array([[1,2,1],[0,0,0],[-1,-2,-1]],dtype='float32')\n",
|
||
" array_1=np.array([[2,1,0],[1,0,-1],[0,-1,-2]],dtype='float32')\n",
|
||
" array_2=np.array([[1,0,-1],[2,0,-2],[1,0,-1]],dtype='float32')\n",
|
||
" array_3=np.array([[0,-1,-2],[1,0,-1],[2,1,0]],dtype='float32')\n",
|
||
" array_4=np.array([[-1,-2,-1],[0,0,0],[1,2,1]],dtype='float32')\n",
|
||
" array_5=np.array([[-2,-1,0],[-1,0,1],[0,1,2]],dtype='float32')\n",
|
||
" array_6=np.array([[-1,0,1],[-2,0,2],[-1,0,1]],dtype='float32')\n",
|
||
" array_7=np.array([[0,1,2],[-1,0,1],[-2,-1,0]],dtype='float32')\n",
|
||
" for i in range(3):\n",
|
||
" kernel[0,i,:]=torch.from_numpy(array_0)\n",
|
||
" kernel[1,i,:]=torch.from_numpy(array_1)\n",
|
||
" kernel[2,i,:]=torch.from_numpy(array_2)\n",
|
||
" kernel[3,i,:]=torch.from_numpy(array_3)\n",
|
||
" kernel[4,i,:]=torch.from_numpy(array_4)\n",
|
||
" kernel[5,i,:]=torch.from_numpy(array_5)\n",
|
||
" kernel[6,i,:]=torch.from_numpy(array_6)\n",
|
||
" kernel[7,i,:]=torch.from_numpy(array_7)\n",
|
||
" return torch.nn.Parameter(kernel,requires_grad=True) \n",
|
||
"\n",
|
||
"# Calculer le poids initial de la couche convolutive add\n",
|
||
"# n, m signifie qu'il y a n * m sous-patches dans le patch d'entrée\n",
|
||
"# Par exemple, le patch d'entrée est 16 * 16, pour les patchs 4 * 4 de la première couche, n = 4, m = 4\n",
|
||
"# pour les patchs 8 * 8 de la deuxième couche, n = 2, m = 2\n",
|
||
"def kernel_add_ini(n,m):\n",
|
||
" input_canal=int(n*m)\n",
|
||
" output_canal=int(n/2)*int(m/2)\n",
|
||
" for i in range(int(n/2)):\n",
|
||
" for j in range(int(m/2)):\n",
|
||
" kernel_add=np.zeros([1,input_canal],dtype='float32')\n",
|
||
" kernel_add[0,i*2*m+j*2]=1\n",
|
||
" kernel_add[0,i*2*m+j*2+1]=1\n",
|
||
" kernel_add[0,(i*2+1)*m+j*2]=1\n",
|
||
" kernel_add[0,(i*2+1)*m+j*2+1]=1\n",
|
||
" if i==0 and j==0:\n",
|
||
" add=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
|
||
" else:\n",
|
||
" add_=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
|
||
" add=torch.cat((add,add_),0)\n",
|
||
" return torch.nn.Parameter(add,requires_grad=False) \n",
|
||
"\n",
|
||
"# Calculer le poids initial de la couche convolutive shift\n",
|
||
"# shift+add Peut réaliser l'étape de l'agrégation\n",
|
||
"# Voir ci-dessus pour les paramètres n et m. \n",
|
||
"# Pour des étapes plus détaillées, veuillez consulter mon rapport de stage\n",
|
||
"def kernel_shift_ini(n,m):\n",
|
||
" input_canal=int(n*m)\n",
|
||
" output_canal=int(n*m)\n",
|
||
" \n",
|
||
" kernel_shift=torch.zeros([output_canal,input_canal,3,3])\n",
|
||
" \n",
|
||
" array_0=np.array([[1,0,0],[0,0,0],[0,0,0]],dtype='float32')\n",
|
||
" array_1=np.array([[0,0,1],[0,0,0],[0,0,0]],dtype='float32')\n",
|
||
" array_2=np.array([[0,0,0],[0,0,0],[1,0,0]],dtype='float32')\n",
|
||
" array_3=np.array([[0,0,0],[0,0,0],[0,0,1]],dtype='float32')\n",
|
||
" \n",
|
||
" kernel_shift_0=torch.from_numpy(array_0)\n",
|
||
" kernel_shift_1=torch.from_numpy(array_1)\n",
|
||
" kernel_shift_2=torch.from_numpy(array_2)\n",
|
||
" kernel_shift_3=torch.from_numpy(array_3)\n",
|
||
" \n",
|
||
" \n",
|
||
" for i in range(n):\n",
|
||
" for j in range(m):\n",
|
||
" if i==0 and j==0:\n",
|
||
" kernel_shift[0,0,:]=kernel_shift_0\n",
|
||
" else:\n",
|
||
" if i%2==0 and j%2==0:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_0\n",
|
||
" if i%2==0 and j%2==1:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_1\n",
|
||
" if i%2==1 and j%2==0:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_2\n",
|
||
" if i%2==1 and j%2==1:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_3\n",
|
||
" \n",
|
||
" return torch.nn.Parameter(kernel_shift,requires_grad=False) \n",
|
||
"\n",
|
||
"# Trouvez le petit patch(4 * 4) dans la n ème ligne et la m ème colonne du patch d'entrée\n",
|
||
"# Ceci est utilisé pour calculer la convolution et obtenir la carte de corrélation\n",
|
||
"def get_patch(fragment,psize,n,m):\n",
|
||
" return fragment[:,:,n*psize:(n+1)*psize,m*psize:(m+1)*psize]\n",
|
||
"\n",
|
||
"###################################################################################################################\n",
|
||
"class Net(nn.Module):\n",
|
||
" def __init__(self,frag_size,psize):\n",
|
||
" super(Net, self).__init__()\n",
|
||
" \n",
|
||
" h_fr=frag_size\n",
|
||
" w_fr=frag_size\n",
|
||
" \n",
|
||
" n=int(h_fr/psize) # n*m patches dans le patch d'entrée\n",
|
||
" m=int(w_fr/psize)\n",
|
||
" \n",
|
||
" self.conv1 = nn.Conv2d(3,8,kernel_size=3,stride=1,padding=1)\n",
|
||
" # Si vous souhaitez initialiser Conv1 avec les poids de DeepMatch, exécutez la ligne suivante\n",
|
||
" # self.conv1.weight=ini()\n",
|
||
" self.Relu = nn.ReLU(inplace=True)\n",
|
||
" self.maxpooling=nn.MaxPool2d(3,stride=2, padding=1)\n",
|
||
" \n",
|
||
" self.shift1=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
|
||
" self.shift1.weight=kernel_shift_ini(n,m)\n",
|
||
" self.add1 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
|
||
" self.add1.weight=kernel_add_ini(n,m)\n",
|
||
" \n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:# Si n=m=1,Notre réseau n'a plus besoin de plus de couches pour agréger les cartes de corrélation\n",
|
||
" self.shift2=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
|
||
" self.shift2.weight=kernel_shift_ini(n,m)\n",
|
||
" self.add2 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
|
||
" self.add2.weight=kernel_add_ini(n,m)\n",
|
||
" \n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:\n",
|
||
" self.shift3=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
|
||
" self.shift3.weight=kernel_shift_ini(n,m)\n",
|
||
" self.add3 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
|
||
" self.add3.weight=kernel_add_ini(n,m)\n",
|
||
" \n",
|
||
" def get_descripteur(self,img,using_cuda):\n",
|
||
" # Utilisez Conv1 pour calculer le descripteur,\n",
|
||
" descripteur_img=self.Relu(self.conv1(img))\n",
|
||
" b,c,h,w=descripteur_img.shape\n",
|
||
" couche_constante = 0.5 * torch.ones([b, 1, h, w])\n",
|
||
" if using_cuda:\n",
|
||
" couche_constante=couche_constante.cuda()\n",
|
||
" # Ajouter une couche constante pour éviter la division par 0 lors de la normalisation\n",
|
||
" descripteur_img = torch.cat((descripteur_img,couche_constante),1)\n",
|
||
" # la normalisation\n",
|
||
" descripteur_img_norm = F.normalize(descripteur_img) #/torch.norm(descripteur_img,dim=1, keepdim = True)\n",
|
||
" return descripteur_img_norm\n",
|
||
" \n",
|
||
" def forward(self,img,frag,using_cuda):\n",
|
||
" psize=4\n",
|
||
" # Utilisez Conv1 pour calculer le descripteur,\n",
|
||
" descripteur_input2=self.get_descripteur(frag,using_cuda)\n",
|
||
" descripteur_input1=self.get_descripteur(img,using_cuda)\n",
|
||
" \n",
|
||
" b,c,h,w=frag.shape\n",
|
||
" n=int(h/psize)\n",
|
||
" m=int(w/psize)\n",
|
||
" \n",
|
||
" db,dc,dh,dw = descripteur_input1.shape\n",
|
||
" \n",
|
||
" #######################################\n",
|
||
" # Calculer la carte de corrélation par convolution pour les n*m patchs plus petit.\n",
|
||
" for i in range(n):\n",
|
||
" for j in range(m):\n",
|
||
" if i==0 and j==0:\n",
|
||
" ##HAD TO CHANGE THIS LINE BECAUSE OF CONVOLUTION DIMENSION FOR BATCHES\n",
|
||
" map_corre=F.conv2d(descripteur_input1.view(1,db*dc,dh,dw),get_patch(descripteur_input2,psize,i,j),padding=2,groups=db)\n",
|
||
"\n",
|
||
" map_corre=map_corre.view(db,1,map_corre.size(2),map_corre.size(3))\n",
|
||
" else:\n",
|
||
" a=F.conv2d(descripteur_input1.view(1,db*dc,dh,dw),get_patch(descripteur_input2,psize,i,j),padding=2, groups=db)\n",
|
||
" a=a.view(db,1,a.size(2),a.size(3))\n",
|
||
" map_corre=torch.cat((map_corre,a),1)\n",
|
||
" \n",
|
||
" ########################################\n",
|
||
" # Étape de polymérisation\n",
|
||
" map_corre=self.maxpooling(map_corre)\n",
|
||
" map_corre=self.shift1(map_corre)\n",
|
||
" map_corre=self.add1(map_corre)\n",
|
||
" \n",
|
||
" #########################################\n",
|
||
" # Répétez l'étape d'agrégation jusqu'à obtenir le graphique de corrélation du patch d'entrée\n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:\n",
|
||
" map_corre=self.maxpooling(map_corre)\n",
|
||
" map_corre=self.shift2(map_corre)\n",
|
||
" map_corre=self.add2(map_corre)\n",
|
||
" \n",
|
||
" \n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:\n",
|
||
" map_corre=self.maxpooling(map_corre)\n",
|
||
" map_corre=self.shift3(map_corre)\n",
|
||
" map_corre=self.add3(map_corre)\n",
|
||
" \n",
|
||
" \n",
|
||
" #b,c,h,w=map_corre.shape\n",
|
||
" # Normalisation de la division par maximum\n",
|
||
" map_corre=map_corre/map_corre.max()\n",
|
||
" # Normalisation SoftMax\n",
|
||
" #map_corre=(F.softmax(map_corre.reshape(1,1,h*w,1),dim=2)).reshape(b,c,h,w)\n",
|
||
" return map_corre"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Dataset and Dataloader"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import os\n",
|
||
"import re\n",
|
||
"from PIL import Image\n",
|
||
"from torchvision import transforms\n",
|
||
"from torch.utils.data import Dataset, DataLoader\n",
|
||
"from time import time\n",
|
||
"\n",
|
||
"\n",
|
||
"class FragmentDataset(Dataset):\n",
|
||
" def __init__(\n",
|
||
" self,\n",
|
||
" fragments_path, \n",
|
||
" train, \n",
|
||
" frags_transform=transforms.ToTensor(),\n",
|
||
" fresques_transform=None,\n",
|
||
" vts_transform=None,\n",
|
||
" ):\n",
|
||
" \"\"\"\n",
|
||
" Parameters\n",
|
||
" ----------\n",
|
||
" fragments_path: str\n",
|
||
" Path to root flder with fragments folders.\n",
|
||
" train: boolean\n",
|
||
" True for train set (__dev__) False for test (__bench__)\n",
|
||
" frags_transform: torchvision.transform\n",
|
||
" Tranform to apply to all fragment images. Default: ToTensor()\n",
|
||
" fresques_transform: torchvision.transform\n",
|
||
" Transform to apply to all fresque images. frags_transform if None.\n",
|
||
" vts_transform: transform to apply to all vts images. Default: ToTensor().\n",
|
||
" \"\"\"\n",
|
||
" self.base_path = fragments_path\n",
|
||
" self.frags_transform = frags_transform\n",
|
||
" self.fresques_transform = fresques_transform if fresques_transform else frags_transform\n",
|
||
" self.fragments_list = []\n",
|
||
" self.vts_transform = vts_transform\n",
|
||
" \n",
|
||
" # To separate between train (dev) and test fragments(bench)\n",
|
||
" self.match_expr = \"_dev_\" if train else \"_bench_\"\n",
|
||
" \n",
|
||
" fragments_path = os.path.join(self.base_path, \"fragments\")\n",
|
||
" for fresque_dir in os.listdir(fragments_path):\n",
|
||
" current_path = os.path.join(fragments_path, fresque_dir)\n",
|
||
" \n",
|
||
" if \"fresque\" in current_path: \n",
|
||
" # Avoids looking at extra files in the dirs.\n",
|
||
" \n",
|
||
" # Get path to current fresque (ie: ..path/fresque0.ppm).\n",
|
||
" fresque_name = current_path.split(\"/\")[-1] + \".ppm\"\n",
|
||
" full_fresque_path = os.path.join(self.base_path, fresque_name) \n",
|
||
" \n",
|
||
" # Get path to every fragment for that fresque (ie: ..path/fresque0/frag_bench_000.ppm)\n",
|
||
" all_fragments_fresque = sorted(os.listdir(current_path))\n",
|
||
" \n",
|
||
" #Get path to every vt for that fresque (ie: ..path/fresque0/vt/frag_bench_000.ppm))\n",
|
||
" vts_path = os.path.join(current_path, \"vt\")\n",
|
||
" all_vts_fresque = sorted(os.listdir(vts_path))\n",
|
||
" \n",
|
||
" # Keep fragments that belong in that set (Train | Test) \n",
|
||
" # group them with the full fresque path (tuple)\n",
|
||
" all_fragments_fresque = [\n",
|
||
" (os.path.join(current_path, frag_path), full_fresque_path, os.path.join(vts_path, vt_path))\n",
|
||
" for frag_path, vt_path in zip(all_fragments_fresque, all_vts_fresque)\n",
|
||
" if re.search(self.match_expr, frag_path) and re.search(self.match_expr, vt_path)\n",
|
||
" ]\n",
|
||
" \n",
|
||
" self.fragments_list.extend(all_fragments_fresque)\n",
|
||
" \n",
|
||
" def __len__(self):\n",
|
||
" return len(self.fragments_list)\n",
|
||
" \n",
|
||
" def __getitem__(self, idx):\n",
|
||
" # Loads the fragment and the full fresque as a tensor.\n",
|
||
" fragment = Image.open(self.fragments_list[idx][0])\n",
|
||
" fresque = Image.open(self.fragments_list[idx][1])\n",
|
||
" \n",
|
||
" with open(self.fragments_list[idx][2],'r') as f:\n",
|
||
" data_vt_raw = f.readlines()\n",
|
||
" data_vt = [int(d.rstrip('\\r\\n')) for d in data_vt_raw]\n",
|
||
" \n",
|
||
" # Construct vt\n",
|
||
" div = 8\n",
|
||
" vt = np.zeros((int(data_vt[0]/div)+1,int(data_vt[1]/div)+1))\n",
|
||
" vt[int(data_vt[2]/div),int(data_vt[3]/div)] = 1\n",
|
||
" vt = np.float32(vt)\n",
|
||
" vt = torch.from_numpy(vt.reshape(1,vt.shape[0],vt.shape[0]))\n",
|
||
" #vt = torch.from_numpy(vt.reshape(1,int(data_vt[0]/div)+1,int(data_vt[1]/div)+1))\n",
|
||
" \n",
|
||
" return self.frags_transform(fragment), self.fresques_transform(fresque), vt"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Usage"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"fresques_tnsf = transforms.Compose([\n",
|
||
" transforms.Resize((1000, 1000)),\n",
|
||
" transforms.ToTensor()\n",
|
||
"])\n",
|
||
"\n",
|
||
"train = FragmentDataset(fragments_path=\"training_data_32\", train=True, fresques_transform=fresques_tnsf)\n",
|
||
"test = FragmentDataset(fragments_path=\"training_data_32\", train=False, fresques_transform=fresques_tnsf)\n",
|
||
"bs = 2\n",
|
||
"train_loader = DataLoader(train, batch_size=bs, num_workers=6, pin_memory=True, shuffle = True)\n",
|
||
"test_loader = DataLoader(test, batch_size=bs, num_workers=6, pin_memory=True, shuffle = True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Train"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Temps par batch: 0.291\n",
|
||
"Temps par batch: 6.76\n",
|
||
"Temps par batch: 5.59\n",
|
||
"Temps par batch: 5.6\n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "KeyboardInterrupt",
|
||
"evalue": "",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||
"\u001b[0;32m<ipython-input-6-7e6697633d7e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mdel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfragments\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;32mdel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfresques\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mvts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m \u001b[0mcost\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mcost\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"frag_size = 32\n",
|
||
"psize = 4\n",
|
||
"num_epochs = 6\n",
|
||
"\n",
|
||
"net = Net(frag_size, psize).cuda()\n",
|
||
" \n",
|
||
"optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n",
|
||
"loss_func = torch.nn.MSELoss()\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"loss_value = []\n",
|
||
"para_value = []\n",
|
||
"\n",
|
||
"time_old = time()\n",
|
||
"\n",
|
||
"torch.cuda.empty_cache\n",
|
||
"optimizer.zero_grad()\n",
|
||
"\n",
|
||
"for epoch in range(num_epochs):\n",
|
||
" for fragments, fresques, vts in train_loader:\n",
|
||
" \n",
|
||
" print(\"Temps par batch: {:.3}\".format(time()-time_old))\n",
|
||
" time_old = time()\n",
|
||
"\n",
|
||
" fragments = fragments.cuda()\n",
|
||
" fresques = fresques.cuda()\n",
|
||
" \n",
|
||
" preds = net(fresques, fragments, True) \n",
|
||
" optimizer.zero_grad()\n",
|
||
" \n",
|
||
" del(fragments)\n",
|
||
" del(fresques)\n",
|
||
" vts = vts.cuda()\n",
|
||
" cost = loss_func(preds, vts)\n",
|
||
" cost.backward()\n",
|
||
" del(vts)\n",
|
||
" optimizer.step()\n",
|
||
"\n",
|
||
" loss_value.append(cost.item())\n",
|
||
" torch.cuda.empty_cache \n",
|
||
" print('Done with epoch ', epoch)\n",
|
||
" # sauvegarder le réseau dans le fichier \"net_trainned6000\"\n",
|
||
"file_path=\"./net_trainned_from-random_full-dataset-small_02-10_07-30_0107\"\n",
|
||
"save_net(file_path,net)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"13500"
|
||
]
|
||
},
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"len(loss_value)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[<matplotlib.lines.Line2D at 0x7f411022f2e0>]"
|
||
]
|
||
},
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAp4ElEQVR4nO3dd3xUVd7H8c9JQkLvkY4BaYKKCiKIFRtlld21rGVdy6rrrnX1eXaxPoq66iq6q+La3V3XBlZWULpYKaFJhxAQEkpCCwkkpJ3nj7kZZpJJMglTcme+79eLF/eee3Lnl5vkN3fuufd3jLUWERGJLwnRDkBERCJPyV9EJA4p+YuIxCElfxGROKTkLyISh5Ki9cLt27e3aWlp0Xp5ERFXWrx48S5rbeqR7idqyT8tLY309PRovbyIiCsZY34KxX502UdEJA4p+YuIxCElfxGROKTkLyISh5T8RUTikJK/iEgcUvIXEYlDrkv+ew8UM/XH7dEOQ0TE1VyX/H//zmJufXcJO/KKoh2KiIhruS75b91TCEBJWXmUIxERcS/XJf/sfZ7kn7nrQJQjERFxL9cl/woZOQXRDkFExLVcm/xFRKT+XJv8NfG8iEj9uTb5i4hI/Sn5i4jEIdcmf131ERGpP9cm/ybJidEOQUTEtVyb/HXiLyJSf+5N/rruIyJSby5O/tGOQETEvVyb/F+etzHaIYiIuJZrk/92VfUUEak31yZ/ERGpPyV/EZE4pOQvIhKHlPxFROKQ65L/gM4tox2CiIjruS75d23TJNohiIi4nuuSv+/DXbn5h6IXiIiIiwWV/I0xI40x64wxGcaYcQG2dzfGzDXGLDXG/GiMGR36UD2SEo13Oa+wOFwvIyIS02pN/saYRGAiMAroD1xpjOlfqdsDwCRr7UnAFcBLoQ60Qlm56jqIiBypYM78hwAZ1tpMa20x8D4wtlIfC1SMxLYCtoUuRH8DOrcK165FROJGMMm/C7DVZz3LafP1MPBrY0wWMA24PdCOjDE3G2PSjTHpubm59QgXkpMOh6zibiIi9ROqAd8rgX9aa7sCo4G3jTFV9m2tfdVaO9haOzg1NbVeL5SS5LoxahGRBieYTJoNdPNZ7+q0+fotMAnAWvsD0BhoH4oAK7vq1O7e5RfmZKiuv4hIPQST/BcBvY0xPYwxyXgGdKdU6rMFOBfAGHMsnuRfv+s6tUhJOjx945Tl21i3Mz8cLyMiEtNqTf7W2lLgNmA6sAbPXT2rjDHjjTEXO93uAW4yxiwH3gOusxE6JS8vj8SriIjElqRgOllrp+EZyPVte8hneTUwPLShBWf089/wxZ1ncGwnlX0QEQlWTIyevjBnQ7RDEBFxlZhI/hrzFRGpm5hI/iIiUjcxkfy37j0Y7RBERFwlJpL/yuz90Q5BRMRVYiL5i4hI3Sj5i4jEISV/EZE4FDPJ/2BxabRDEBFxjZhJ/jf9Oz3aIYiIuEbMJP/vMnZHOwQREdeImeQvIiLBU/IXEYlDSv4iInFIyV9EJA65MvkP6dE22iGIiLiaK5P/KWltoh2CiIiruTL5JxgTsL3nvVO5e9KyyAYjIuJCrkz+pprkX27h4yXZEY5GRMR9XJn8E6tJ/iIiEhxXJv/GjVwZtohIg+HKLHrtaWnRDkFExNVcmfwbN0qMdggiIq7myuQvIiJHJiaT/5rtmtNXRKQmrk3+J3RtVe22UX//hi9Xbo9gNCIi7uLa5P/kL0+ocfs3G3ZFKBIREfdxbfKvjXX+X71tP8OfnMPeA8VRjUdEpCGJ2eRfYeJXGWTvK+SbDH0SEBGpELPJ31rPub+ptC4iIi5O/p1bN65x+3sLt1JcWl5tHSARkXiWFO0A6qt10+Ra+/z6jQUs3LQHAJ34i4gc5toz/2BUJH4Ai7K/iEgFVyf/RfefF+0QRERcydXJP7VFStB9rYVvN+yiqKQsjBGJiLiDq5N/Xazbmc+v31jAw1NWRTsUEZGoCyr5G2NGGmPWGWMyjDHjqulzuTFmtTFmlTHm3dCGWb2UpODev/IOlgCQkVMQznBERFyh1sxpjEkEJgKjgP7AlcaY/pX69AbuBYZbawcAd4U+1MAeHXtcUP02OElfw74iIsGd+Q8BMqy1mdbaYuB9YGylPjcBE621ewGstTmhDbN6l5/SLah+i3/aG+ZIRETcI5jk3wXY6rOe5bT56gP0McZ8Z4yZb4wZGWhHxpibjTHpxpj03Nzc+kV8hPSkr4hI6AZ8k4DewNnAlcBrxpjWlTtZa1+11g621g5OTU0N0UvXzZIt+6LyuiIiDUkwyT8b8L220tVp85UFTLHWllhrNwHr8bwZiIhIAxRM8l8E9DbG9DDGJANXAFMq9fkUz1k/xpj2eC4DZYYuzPDJLyqhtKw82mGIiERUrcnfWlsK3AZMB9YAk6y1q4wx440xFzvdpgO7jTGrgbnA/1prd4cr6FApL7cc//AM7vxgWbRDERGJqKAKu1lrpwHTKrU95LNsgbudf67x7x82AzD1x+1MvCq6sYiIRFLcPOHra+5az52o2fsKoxyJiEh0xETyv2xQ1zr1v+O9pRwsLuW1bzb5tX+/cRdp46ayMjsvlOGJiDQ4MZH86yr/UCnLtu6r0j5rtecTwfzMBj9cISJyRGIi+dfrsS096yUicSwmkn95PZ7aPVDsX9o5bdxUDhaXhiokEZEGLSaSf33c9O/0Km3vL9oaoKeISOyJjeSvSzgiInUSG8lfRETqJCaSf8dWjaMdgoiIq8RE8r/rvD7RDkFExFViIvknBzmVo4iIeChrhtCGnfms35kf7TBERGql5B/Ayuw874TvdXH+c19zwXNfhyEiEZHQUvIP4NNl2xg4fgaZuQVV3gSKS8tZoPIPIuJySv41GDFhHqc+Mcuv7bZ3l/CrV+ezZvv+KEUlInLklPxrUVRSzn2frPCuz1i9E4D0zXuiFZKIyBGLueR/x7mhnzr43QVbqrQ9+NmqkL+OiEikxFzyv/60tGiHICLS4AU1jaMbzLnnLA4Wl2FMtCMREWn4Yib590xtDkBeYd1v0RQRiTcxd9kn0mf+2/YVkjZuKtNWbI/sC4uIHIGYS/4pYSr18I+vNmIDTBpz2pNzAHjos5VheV0RkXCIweSfyIL7zqVRYmg/Ajz15Vpe+mpjSPcpIhItMZf8ATq0bMznt58R8v0+PX1dtds01iAibhKTyR+gb8cWfPT70yL2eiVlmk5MRNwjZpM/wKCj20Q7BBGRBimmk3+4DXl8Vo3bN+YW8PzsDRGKRkQkeEr+RyAn/1DA9u15hQBc9dp8np25nr0HiiMZlohIrZT8w2DKsm0AHCotB0CjASLS0Cj5h8Hm3QfYuucgqjQhIg1VzJR3aEjeW7iV9xZupU3TRgABHw4TEYkmnfmHkVGVORFpoJT8I8D3vN9aS1m5PgmISHQp+YfRHucun9/+cxGfLcsG4I1vN3HMfdN0B5CIRFVcJf/xYwdE5XWXZ+Vx5/vLAJicngXAjv1FUYlFRATiJPkP7NaaaXecwW+GpUU7lColp/OLSig4VMqugsDPDIiIhENQd/sYY0YCfwcSgdettU9W0+8S4EPgFGttesiiPAIz/ngmnVo1pkXjRtEOhc+WZbN2Rz4A1sKh0jKOf3iGd/v0u86kb8cWAKzZvp9+HVto0FhEwqLWM39jTCIwERgF9AeuNMb0D9CvBXAnsCDUQR6JPh1aNIjED3gv/QBYLA9P8Z8EfmNuAQAzV+9k1N+/4ZOl2ZEMT0TiSDBn/kOADGttJoAx5n1gLLC6Ur9HgaeA/w1phDFqzPPfVmmreBwgI8fzJrBuZ34kQxKROBLMNf8uwFaf9SynzcsYczLQzVo7taYdGWNuNsakG2PSc3Nz6xxsrLPOTaFWBSFEJMyOeMDXGJMAPAvcU1tfa+2r1trB1trBqampR/rSMafy7f/rd+jMX0TCI5jknw1081nv6rRVaAEcB3xljNkMDAWmGGMGhyrIUGrfPIXWTRvGGEBls9fsZGNugffyz9x1+nQkIuERTPJfBPQ2xvQwxiQDVwBTKjZaa/Oste2ttWnW2jRgPnBxQ7nbp7IF953L4gfOj3YYAX22bBvnTpgX7TBEJA7UmvyttaXAbcB0YA0wyVq7yhgz3hhzcbgDDLXEBENigntun1yyZW+0QxCRGBTUff7W2mnAtEptD1XT9+wjD0sq/PKl79n0xGjd7y8iIRUXT/gGkpLknm/9nknLox2CiMQY92TAEPvVKd1q7xQllev/f6yHvUQkxOI2+T/4syoPKTcYz8xYH+0QRCTGxW3yb5R4+Fvv2LJxFCMREYm8uE3+AH8a2ZcbhvegaXJitEOp1Zcrt0c7BBGJIXGd/P9wdi8euqjhXv7xdct/lrAyOy/aYYhIjIjr5F+hXfPkaIcQlIJDpdEOQURihJI/8NLVg6IdQlDWbt/PG99uAqC83FJaVh7liETErZT8gdQWKd7lO0b0om+HFrx53WBuPeeYKEZV1cP/Xc2jn6+mvNxy07/T6XX/F2Q6cwCIiNSFkn8lY0/qwvQ/nsmIfh24fniPaIcT0N9mrWf22hwARkyYp8FgEakzJX9HRfWEHu2aedvaN0+ppnd0PT8nw2/9lv8siVIkIuJWSv4Oq/lTRCSOKPk7Jlw2kO5tm6L6aSISD4Kq6hkPLhnUlUsGdY12GCIiEaEz/1o8MOZYzuqTyobHR0U7lKB9tDiLr9blVCkQJyJSQWf+tbjxjJ7ceEbPaIdRqwOHSmmWksSKrDzumewpAf1/F/VvsHcsiUh06cw/Rjz02SoAdh045G3bkKNnAEQkMCX/GLG7Iun7XOkJdNXnUGkZuwsOVd0gInFFl31ixNrt+Xy4OAvf6Yl9r/lv3XOQDTn5/Gf+FuaszWHzk2OiEKWINBRK/vWwZvxIjn3oy2iH4WfH/iL+Z7L/dI9frtpB344tKCop56kv10YpMhFpiJT86+CsPqns3F9EExfU/wfYd7CER/67OtphiEgDpORfB/+6YUi0QxARCQkl/zi1bV8hew4U0655Mp1aNQHggufmUVpmmfM/Z0c3OBEJOyX/I3RC11b8mOW+GbZOe3KOd7li8Hf9Ts+toXkHS2jVtBEA32XsYvW2/dx0ZsN/1kFEgqdbPY/QJ38YHu0QQu6Sl7/3Ll/9+gIen7YmitGISDjozP8IJSa4vxJcwaFSlm/d513P0MNhIjFPZ/4h8Nmtw3ns58d515OT3HVYd+QVcs+k5bV3FJGY4a4s1UAN7NaaXw892rt+dp/UKEZTd+c9+zU79hf5tS3ctIeJcw9PGjNr9U7Sxk1lf1FJpMMTkTDQZZ96eunqk2ndpFHAbaOO70hRaTlfr88F4M8j+7nuIavLX/nBb/0F541gY04BJ3VvE42QRCSEdOZfT6OP78RpvdoH3Nb7qBY0aXT40P7+7IY1EXx9FBWXRTsEEQkhJf8wOK5LKxJibEqwdTvzATAx9n2JxCsl/xC7xrn2H6s5cvu+Qjbm6m4gEbfTNf8Qynh8lPfWT0NsZv/fv7MEoN5VQees3Umv1BZ0b9e0Tl+XkZPPkp/2cfkp3er1uiLiT2f+IZSUmHD4skil3P/P609hSI+2kQ8qQopLy/nTh8vZkee5a+jHrH18tDirSr8b/pnOORO+qvP+L3jua/700Y/Vbj9UqjEJkbpQ8g+Tyuf9Z/c9imcuHRiVWMJhRVYeP2bt867PXrOTSelZ3D1pGTvyirj4xe+4Z/Jy5q3PZdHmPX5fW1Ze97mFa/qS9Tvz6fvAl3z+47Y671ckXin5h0mgeX+PapkCwPFdWvm1L37gvIjEFEoXvfgtF7/4HdZa9hwopsyZOOb7jbsZ+sRsb79r31zIZS//wBcrtvNdxi5ve/rmPZSWlYcklpXZntpKs1bvDMn+avPNhlyWbtkbkdcSCZegkr8xZqQxZp0xJsMYMy7A9ruNMauNMT8aY2YbY44OtJ94coKT4H0Hfhs3SmTzk2OYcttwVo+/0NvernlKpMMLmRfnZHDyozO57d2lNfb7/TtLuPr1Bd71S1/+gb/P3lCl3879RVhrOVRaxohnvmKe86xETSqOcTCfJ8rLLeU1fIxYkZXHA5+u8JsFrbJr3ljIL17y1D9aumUvw5+co4ffxHVqTf7GmERgIjAK6A9caYzpX6nbUmCwtfYE4EPgr6EO1G0qElK3NlUHNo0xNE32H2tvnuLOsfcJM9fX+2vXO7ePVtiwM59T/zKbV7/OZNu+IjJ3HeD/Plvp16cicR8sLiVt3FTeXbDFO7hu/eYvtny8JIuikjLv+uT0rQwcP4NTHp9VbUxXvz6f/8zfwv7C0qC+h2dnrid7XyFLt+wLqr9IQxHMmf8QIMNam2mtLQbeB8b6drDWzrXWHnRW5wNdQxum+xhjePWaQUy+ZVi1fd6/eShvXjcYgJWPXFhtv1iVYAy7Cw6xapvnss39n3gS/RNfrOXJLzyVRDfvPuj3NZMXb+Xsp+eyc79nEvr7PllBbn7VCennrsvh7knLmTBjHQDpP+3lfz/8kfyiUnYfKK42prqPRnjE5r1dEsuCSf5dgK0+61lOW3V+C3wRaIMx5mZjTLoxJj03t/aP8253wYCOdGjZuNrtQ3u2Y0S/Dt71caP6RSKsBmPWmp0MemwWY57/FoCFPgPD01cFvn7/549WsHn3QQ4WHz4zryg5vWDTbm9bXqHnMkzFG0NhsE8oO9l//Oe1T3854pmv/D5tzF2Xw+vfZDLsidneTxwiDVVIB3yNMb8GBgNPB9purX3VWjvYWjs4NdVdxc8i4ZazjmHh/ecy6+4zeeHKk6IdTtiVlB3OnCeOn1Ftv8npW6u0Xf7yD1XaKj4NAKzM3g8cfiK5cunttHFTGfToTCb57PtQaRn5hzxvKh8t8b9N9R9fbeTGfy3ya8vcdcC7bAxc/9YiHpu6hu15Rbz+TWa134+vN7/dxE+7D9TesQ4+XpJF9r7CkO5TYk8wyT8b8H2ypqvT5scYcx5wP3Cxtbbq53AJylEtGtPrqBZcNLBztEOJqH0Hqx8wffKLqkXxDlRzJr8xt4Bb313CG99uAuCTpdmkjZvqN9hcYfeBYv70oefZgf1FJfR94MuA+7zkH9/z1JdrmbUmh90F/r/atpoLRc/MqH0spOBQKeM/X80Vr86vtW+wikvLuXvS8oBvjiK+gkn+i4Dexpgexphk4Apgim8HY8xJwCt4En9O6MOMTx1apnDj6T3Y/OQY/nrJCdEOJ2pqukZf2bkT5jH1x+112n9mbgEnPFz1k8egR2fyw8bdLP7p8G2dgx7zHyyuuOwT7BPdxaXlTP1xO9Za7x1F+wtDd6dQxZtRoHEQEV+1Jn9rbSlwGzAdWANMstauMsaMN8Zc7HR7GmgOTDbGLDPGTKlmd1IHC+47jwd+5rmxSmUNwmfEhHkB23cfKObK12o+K69I/pMXV700Fchzs9Zz67tL+Gp9rrf4X22DzAs37eG1rzMpKSsPeJvqlt0H/WZiEwlGUPcXWmunAdMqtT3ks+y+p5REQqDYeVBt867qr9vnHSyhaUoijRITyN5b6G3zPp8QIPtPmLGOF+ZkkPmX0d65FR6ftoZfnNSF5351ol/fM5+e64mhnvWWJD7pCV8X+dPIvowb1U9/5A1IxSWh5Vl5VbYt3bIXay0Dx8/gbmeazE3Om8Sh0jJudx6MKyot87tM8/zsDbwwxzN5Tlmld4ZPllYZbhOpFyV/F/nD2b245azAE8Pcc36fCEcjtfnFS9/T417PB+b/Lt/G1j0HWeGUovjv8u3MXusZHrMWTnl8FvuLStiYW8CzR/DgnEiw3PlYqfjpmdqME7u3jnYYUosz/jrXu/ytT52jCoEGnb/fuLtKm0go6MzfpSZedTJvXXcKAD3bN/PbdvnguH/AOmZc++bCKm0Vzybc9O90Zq+pfzG7t77bxMwIFcOThkdn/i415oROALxyzSCGHdPOO8fui1edxE9OSYRrhh5NflEJvzi5K9e+uZCLB3ZmynKVPXa7xz5fzeWDuzFz9U6/5L274JB3ADoYj/zX8xSzxpDik5K/y104oCMALRs38v4RV9zn3q9TC64+1VNgtWKbkn9sODfAhDi+zyAUl5Wzett++nduGcGoxE102ScGjT6+I+/dNJSrhnSPdigSBvuLStmYW3tJiNeCLDEBsPdAMZ86dxLlHSwJ+XMDWXsPVlv2esrybaSNm8q+g8E/zCdHTsk/BhljGHZMu8NTSvqIx+qh8eqTpdlVHgpbvW1/wAfF7nh/KXd9sIzvM3YxduK3jJ34nd/2DxZtCbpeUSCnPzWXUX/7JuC2N51SHJk1PCshoafLPnGm8rwBSx48n5MfnVml36jjOvLFyh2RCkvC5J2FW1i2ZR/9OrbwVj8dktaW+8Yc69evYu7lqyrVQDpUWkZ5uaeaKgSeoS5Y1RWbS/A+7FbfgtpSHzrzj3OVq10O7OqZgexXKicREx78dCUfLcnyJn7wlM7+uc+ZffrmPWzIKajytdZazn/2a4596HDBu5rmXy4sLuPO95eSk19UpxgrPqEG2nXBoVIOHKp5Yh1rLS/M3lDjU9ZSlc7849Dnt5/Oiuw8zuqTSqPEw8n/zesGM6JfB6y1GGMY2rMt8zMP19j/7ek9vNUyJXZcWk0FUGthyx7/yXRe+yaz2gcNr3p9Pku37CPBmColKGpScf6xdvt+NuYUkJt/iOuGp7Fw0x5++690oOY7knILDjFh5nomzFzPFad0497Rx9KqSaOgX786K7PzuOGfi5h+15m0aZZ8xPtraHTmH4eO69KKK4d0p3PrJqQkJXrbKyaWqTgT+9uv/OcUOMnnQbILB3RAYlugc/yKW0tfmL2BP36wzG9bxVSW+UUlbN51IOBlnOdmrq8y0U3FbG0PfraKcR+vYMLM9Rz/8Axv4gdPWYyXvvKUvMgrLGGXb2ltn5d5f9FWBj5S/dwQdTFxbgY5+Ydi9kE7Jf84l5hgaNssmaSEmksSd23ThDHHd2JEv6N4+KL+nNP3qHq/5nnH1v9rJXLyApSarqhlNGHmer86Q1+tO1zJ/buM3Zz9zFd8uDiLtTv2c9u7S7zb/j57A/0e/JKMnAK+37iLtHFTgyo/fcWrP/DXL9eRV1jCoEdnMti3tHaAX90/frDMe4mqqKTMO6YRSFm5ZcmWvVXaveW6Y3SOTl32Eebfe27A9qNapHDZoK78ZlgaxztjAW86TxVn7fWcrV00sDOPXDyARomG4wOUJ/jyrjPofVQL5mfu5urXF/DIxQO4ckh3+jwQcKZPaUAC3QgA/gOzY1/8lksHd+PBT1d62wqdM/u3vttMWbll3c78Kvs479l5tGgcfPo5eOjwp4VSn8GBuyctIy/ARECfLM3mngv60Cw5id+9vZiFm/dUe+lo4twMnp25no9+P4xBR7etsj1Gc7+Sv0ByUuAPgAkJhqcvGxhwW9c2TWu8Djv5lmGktWtGaosUAIb3aq8nSWPERS9+611enpUXsKIpwOrt+6v93QLIL6p5INdXRbqvOOkAz5vQx0uqr3J6+lNz/db3F5WwI6+IlKQEfvf2Yv5z46nMWZPjLba3Iy+4WdrAMxCdvbeQvh1b1Bp7ebllRXYeA7u1rrVvJCn5S8i9/pvBnJJW9QzK16e3Die/qITr31pEabnljN7tWfzTXg5WMz3jHSN68bxT5rgmt4/oRasmjXhs6ppa+0r9VMyPHIzi0uDLTdSkwLnjZ8zzh994Lnju6zrt44lpa3lv4Rbv+uBKs7JVVnGpKNBlnxveWlTjpwnwDBj379SS177J5Ikv1vL+zUMZ2rNdnWIOJyV/CZnZ95zF1j0HOTuI8YATnbOgjL+MZmNuAZ1aNeaXL33P2h1VLxEAjDyuE+2ap9C9XVPmrcvln99v5soh3XhvoafIWaNEw1VDunPPBX0Bakz+1ww9mrfn/1TH704amkC3p9Zk3rqaZ5jdvNv/VtHqPtF8vT6XhZs9d8F9syGXM3qnVukzY9UObn57MXef38c7h0PFRD4NhYnWgxWDBw+26enptXeUuPHinA08M8Nzu97WvQe5dFBXpizbxuO/OJ7OrZtU6V9ebnnws5X0aN+sysNHl/zje7+5dyssvO9cjmrZGIC0cVPD842Ia/meyVf+/Zj5xzM5P8CnjX/dMITC4lIu6N+RhARDTn4RQx6fHXD/0+44g5z8oqBOkKpjjFlsrR1c7x1U7EfJXxoKay0HisuqPIVcH4dKy/hwcRb3f+IZiLzz3N5ce1oabX3u104bN5WmyYlVLjUd1SKFP4/sxxl92vv9Ef96aHeuOKU7P3vhW0Ll6lO7886CLbV3lIh47lcD+cVJnpLo4Tw5yHh8FEmJ9bvZUslfJAhl5ZY12/dzXJdWVbZl7T1Is+Qklm7dy54DJWzfV8iEmesZktaWSbcM8/b7bFk2d76/jFWPXEizlCQmzs1g3rpcPvjdUO9MXRUmXDaQeyYvDyq26XedyYJNu3nos1VH9k1KSD0w5tiwjxk9MObYepfKCFXy1zV/iWmJCSZg4gfPHUtw+OG2+ZnOwzyVBvjGntiFsSd28a7fek4vbj2nV8B99ut0+O6P6XedyeKf9rJzfxHZ+wr5cHGWd5sx0LdjC/IrVbrc/OQY/rt8G7e/tzS4b1BCLhI3C+xtABVMlfxFHL2Pag54BoTrqk+H5qzfWUBSQoJfW8WtgIXFZZzfvwN9OrTgnGe+YurtZwDQw5mFbWjPtjz5yxOO9FsQl0hMiP7ztUr+Io52zVPq/CzChMsGUm4ty7buY/3OAto0PVxTxrekdpPkRO/EO76v0a55CpN+N4z+nVvWOtbx1vWncFbvVJ78ci2vfl3/8soSfbU9UR8J0X/7EXGxSwZ15bLB3fi/iwYw6+6zvHcS1cWQHm39En/Fe8axnVqS+ZfR3vYEY0hIMNw3+tjKuxCXqVxNNxqU/EVCIDkpgV7OZaNnLhvIuzeeWu99nd6rPV1aN+GZy04gIcGw6P7zuOmMHpzeq723zz+uPhnwPPzm6/PbT2fjX0bzyjWD6v36En4FtZSpjgRd9hEJsUsHdT2ir2/dNJnvxo3wrqe2SOH+Mf39+ow6vpP38lFxmeXleRv59Nbh3sHtCwd0pG+HFgzv1Z5Za3ZWKc3s6/gurejQsjGz1uysto+EVmlZaJ58PhJK/iIu9z8X9OHnJ3WmX0f/ydqn//FMAK47LY3pq3b4TegCMOaETgzt2c47wD3679+wenvwpRuk/upS1yhcdNlHxOWSEhOqJH5f3ds15aYze/Labzy3ho8+viPHpDZj3Mh+fnc2jTyuY42vs+yh83mmmkJ/UjeTfW77jRad+YvEifP7d6jxbqbbzunFtcPSaNW0Edv2FXLak3P8trdumszQnp6Cffec34dteYXceW4fnpmxzu8ZhktO7spHSwInt/P7d/BOCFPhmcsGsvinvX5F12JdB6fabTQp+YsI4Cnh3cq5VbVz6yZMvmUYs9fk8IdzjqHQKYERqJT3vaP6UV5umbJ8GxMuH8jYE7v4JX/f/kUlZfxl2hrSN+9lW14hU249ne7tmrLngKeccpumjfjPjaf6Ve+867zeDDq6Dde8sRCAC/p3YMZqd49P/PykLrV3CjMlfxEJ6JS0tt7S3C0bVz8nbrvmKTz7qxN51mfe3tn3nMUT09YyfuwAv76NGyUyfuxxVfaR4Nzf+suTuzKg8+EnsqfecTr9OrYkMcGQYDyTvL9yzSAe+e9qBnZrxXsLtrJw8x7+eskJrNqWx/XDe1BSVs61by5km1OSeUDnlqza5j+WMfmWYbRtlkzHlo1JMIasvQfp1rYpr8zL5LlZ6+t2oOrhtGPa194pzJT8RSTkjkltzuvXBl9+puIupUFHtwE8dyDlF5X4vRHMv/dc8gpLMMbw8MWeN5WmyUks3LyH8/p34PJTunn73nleb/780QoAPvnDcPYVFpOUkEB+UQkZOQVV5pvo3aGF9+t+M+xolmzZy92TlpNXWML6x0axYNNurnljIU2TE/nh3nNp6cxC5lvb6Ys7z+Dt+T/xrlOor22zZPYc8JRxuHRQV79LY/07Vz9GEykq7CYiDUJu/iHvzG9Hqqzc8u8fNnPVqd1JSUqs1z4ycvKZszaHm888BvBcsmrcyH9fuwsOYYEWjZNISUqkvNxSWm4pt5bGjRLZkVfE/Mzd/PykLny5cgcndG0VsDx5Xaiqp4hIHApV8tetniIicUjJX0QkDgWV/I0xI40x64wxGcaYcQG2pxhjPnC2LzDGpIU8UhERCZlak78xJhGYCIwC+gNXGmP6V+r2W2CvtbYX8BzwVKgDFRGR0AnmzH8IkGGtzbTWFgPvA2Mr9RkL/MtZ/hA41/gWMxcRkQYlmOTfBdjqs57ltAXsY60tBfKAdpV3ZIy52RiTboxJz83NrV/EIiJyxCI64GutfdVaO9haOzg1NTWSLy0iIj6CSf7ZQDef9a5OW8A+xpgkoBWwOxQBiohI6AVT3mER0NsY0wNPkr8CuKpSnynAtcAPwKXAHFvL02OLFy/eZYz5qe4hA9Ae2FXPr40WxRw5boxbMUdGLMR8dHUd66LW5G+tLTXG3AZMBxKBN621q4wx44F0a+0U4A3gbWNMBrAHzxtEbfut93UfY0x6KJ5wiyTFHDlujFsxR4ZiPiyowm7W2mnAtEptD/ksFwGXhTY0EREJFz3hKyISh9ya/F+NdgD1oJgjx41xK+bIUMyOqFX1FBGR6HHrmb+IiBwBJX8RkTjkuuRfW4XRCMfSzRgz1xiz2hizyhhzp9Pe1hgz0xizwfm/jdNujDHPO7H/aIw52Wdf1zr9Nxhjrg1z3InGmKXGmM+d9R5ONdYMpzprstNebbVWY8y9Tvs6Y8yF4YzXeb3WxpgPjTFrjTFrjDHDXHCc/+j8Xqw0xrxnjGnc0I61MeZNY0yOMWalT1vIjqsxZpAxZoXzNc8bc+Q1v6qJ+Wnnd+NHY8wnxpjWPtsCHr/qckl1P6NQx+yz7R5jjDXGtHfWI3OcrbWu+YfnOYONQE8gGVgO9I9iPJ2Ak53lFsB6PJVP/wqMc9rHAU85y6OBLwADDAUWOO1tgUzn/zbOcpswxn038C7wubM+CbjCWX4Z+L2z/AfgZWf5CuADZ7m/c+xTgB7OzyQxzMf6X8CNznIy0LohH2c89a42AU18jvF1De1YA2cCJwMrfdpCdlyBhU5f43ztqDDFfAGQ5Cw/5RNzwONHDbmkup9RqGN22rvheYbqJ6B9JI9z2P5Yw/QHNQyY7rN+L3BvtOPyiecz4HxgHdDJaesErHOWXwGu9Om/ztl+JfCKT7tfvxDH2BWYDYwAPnd+WXb5/OF4j7HzSznMWU5y+pnKx923X5hiboUnkZpK7Q35OFcUO2zrHLvPgQsb4rEG0vBPpCE5rs62tT7tfv1CGXOlbb8A3nGWAx4/qsklNf09hCNmPFWQBwKbOZz8I3Kc3XbZJ5gKo1HhfEw/CVgAdLDWbnc27QA6OMvVxR/J7+tvwJ+Acme9HbDPeqqxVn7t6qq1Rvrn0APIBd4ynstVrxtjmtGAj7O1Nht4BtgCbMdz7BbT8I81hO64dnGWK7eH2w14zn6pJbZA7TX9PYSUMWYskG2tXV5pU0SOs9uSf4NkjGkOfATcZa3d77vNet6KG8T9tMaYnwE51trF0Y6ljpLwfGT+h7X2JOAAnssRXg3pOAM418nH4nnj6gw0A0ZGNah6aGjHtTbGmPuBUuCdaMdSE2NMU+A+4KHa+oaL25J/MBVGI8oY0whP4n/HWvux07zTGNPJ2d4JyHHaq4s/Ut/XcOBiY8xmPJPyjAD+DrQ2nmqslV+7umqtkf45ZAFZ1toFzvqHeN4MGupxBjgP2GStzbXWlgAf4zn+Df1YQ+iOa7azXLk9LIwx1wE/A6523rSoJbZA7bup/mcUSsfgOTFY7vw9dgWWGGM61iPm+h3nUF47DPc/PGeAmc5BqxikGRDFeAzwb+Bvldqfxn/A7K/O8hj8B3IWOu1t8VzTbuP82wS0DXPsZ3N4wHcy/gNcf3CWb8V/EHKSszwA/0G0TMI/4PsN0NdZftg5xg32OAOnAquApk4c/wJub4jHmqrX/EN2XKk6EDk6TDGPBFYDqZX6BTx+1JBLqvsZhTrmSts2c/iaf0SOc9j+WMP1D89I+Ho8I/X3RzmW0/F8JP4RWOb8G43nuuFsYAMwy+cHZPDMh7wRWAEM9tnXDUCG8+/6CMR+NoeTf0/nlyfD+cVPcdobO+sZzvaePl9/v/N9rCMEd3AEEe+JQLpzrD91fvkb9HEGHgHWAiuBt50E1KCONfAenjGJEjyfsH4byuMKDHa+/43Ai1QatA9hzBl4rodX/B2+XNvxo5pcUt3PKNQxV9q+mcPJPyLHWeUdRETikNuu+YuISAgo+YuIxCElfxGROKTkLyISh5T8RUTikJK/iEgcUvIXEYlD/w/oPLf/C7zPWgAAAABJRU5ErkJggg==\n",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.plot(loss_value)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"ename": "NameError",
|
||
"evalue": "name 'net' is not defined",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||
"\u001b[0;32m<ipython-input-5-47d0cd6ed22d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mfile_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"./net_trainned6000_MB_102\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msave_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||
"\u001b[0;31mNameError\u001b[0m: name 'net' is not defined"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"file_path=\"./net_trainned6000_MB_102\"\n",
|
||
"save_net(file_path,net)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.8.5"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|