589 lines
40 KiB
Text
589 lines
40 KiB
Text
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#Tous les codes sont basés sur l'environnement suivant\n",
|
||
"#python 3.7\n",
|
||
"#opencv 3.1.0\n",
|
||
"#pytorch 1.4.0\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"from torch.autograd import Variable\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"import cv2\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import numpy as np\n",
|
||
"import random\n",
|
||
"import math\n",
|
||
"import pickle\n",
|
||
"import random\n",
|
||
"from PIL import Image\n",
|
||
"import sys\n",
|
||
"from glob import glob\n",
|
||
"from IPython.display import clear_output\n",
|
||
"from datetime import datetime\n",
|
||
"from time import time"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"#Les fonctions dans ce bloc ne sont pas utilisées par le réseau, mais certaines fonctions d'outils\n",
|
||
"\n",
|
||
"\n",
|
||
"def tensor_imshow(im_tensor,cannel):\n",
|
||
" b,c,h,w=im_tensor.shape\n",
|
||
" if c==1:\n",
|
||
" plt.imshow(im_tensor.squeeze().detach().numpy())\n",
|
||
" else:\n",
|
||
" plt.imshow(im_tensor.squeeze().detach().numpy()[cannel,:])\n",
|
||
"\n",
|
||
"# Obtenez des données d'entraînement\n",
|
||
"# frag,vt=get_training_fragment(frag_size,image)\n",
|
||
"# frag est un patch carrée de taille (frag_size*frag_size) a partir du image(Son emplacement est aléatoire)\n",
|
||
"# vt est la vérité terrain de la forme Dirac.\n",
|
||
"def get_training_fragment(frag_size,im):\n",
|
||
" h,w,c=im.shape\n",
|
||
" n=random.randint(0,int(h/frag_size)-1)\n",
|
||
" m=random.randint(0,int(w/frag_size)-1) \n",
|
||
" shape=frag_size/4\n",
|
||
" vt_h=math.ceil((h+1)/shape)\n",
|
||
" vt_w=math.ceil((w+1)/shape)\n",
|
||
" vt=np.zeros([vt_h,vt_w])\n",
|
||
" vt_h_po=round((vt_h-1)*(n*frag_size/(h-1)+(n+1)*frag_size/(h-1))/2)\n",
|
||
" vt_w_po=round((vt_w-1)*(m*frag_size/(w-1)+(m+1)*frag_size/(w-1))/2)\n",
|
||
" vt[vt_h_po,vt_w_po]=1\n",
|
||
" vt = np.float32(vt)\n",
|
||
" vt=torch.from_numpy(vt.reshape(1,1,vt_h,vt_w))\n",
|
||
" \n",
|
||
" return im[n*frag_size:(n+1)*frag_size,m*frag_size:(m+1)*frag_size,:],vt\n",
|
||
"\n",
|
||
"def load_training_fragment(fragment_path,vt_path):\n",
|
||
" # Load fragment\n",
|
||
" frag = cv2.imread(fragment_path)\n",
|
||
" \n",
|
||
" # Load vt data\n",
|
||
" with open(vt_path,'r') as f:\n",
|
||
" data_vt_raw = f.readlines()\n",
|
||
" data_vt = [int(d.rstrip('\\r\\n')) for d in data_vt_raw]\n",
|
||
" \n",
|
||
" # Construct vt\n",
|
||
" vt = np.zeros((int(data_vt[0]/4)+1,int(data_vt[1]/4)+1))\n",
|
||
" vt[int(data_vt[2]/4),int(data_vt[3]/4)] = 1\n",
|
||
" vt = np.float32(vt)\n",
|
||
" vt = torch.from_numpy(vt.reshape(1,1,int(data_vt[0]/4)+1,int(data_vt[1]/4)+1))\n",
|
||
" \n",
|
||
" return(frag,vt)\n",
|
||
"\n",
|
||
"\n",
|
||
"# Cette fonction convertit l'image en variable de type Tensor.\n",
|
||
"# Toutes les données de calcul du réseau sont de type Tensor\n",
|
||
"# Img.shape=[Height,Width,Channel]\n",
|
||
"# Tensor.shape=[Batch,Channel,Height,Width]\n",
|
||
"def img2tensor(im):\n",
|
||
" im=np.array(im,dtype=\"float32\")\n",
|
||
" tensor_cv = torch.from_numpy(np.transpose(im, (2, 0, 1)))\n",
|
||
" im_tensor=tensor_cv.unsqueeze(0)\n",
|
||
" return im_tensor\n",
|
||
"\n",
|
||
"# Trouvez les coordonnées de la valeur maximale dans une carte de corrélation\n",
|
||
"# x,y=show_coordonnee(carte de corrélation)\n",
|
||
"def show_coordonnee(position_pred):\n",
|
||
" map_corre=position_pred.squeeze().detach().numpy()\n",
|
||
" h,w=map_corre.shape\n",
|
||
" max_value=map_corre.max()\n",
|
||
" coordonnee=np.where(map_corre==max_value)\n",
|
||
" return coordonnee[0].mean()/h,coordonnee[1].mean()/w\n",
|
||
"\n",
|
||
"# Filtrer les patchs en fonction du nombre de pixels noirs dans le patch\n",
|
||
"# Si seuls les pixels non noirs sont plus grands qu'une certaine proportion(seuillage), revenez à True, sinon False\n",
|
||
"def test_fragment32_32(frag,seuillage):\n",
|
||
" a=frag[:,:,0]+frag[:,:,1]+frag[:,:,2]\n",
|
||
" mask = (a == 0)\n",
|
||
" arr_new = a[mask]\n",
|
||
" if arr_new.size/a.size<=(1-seuillage):\n",
|
||
" return True\n",
|
||
" else:\n",
|
||
" return False\n",
|
||
" \n",
|
||
"# Ces deux fonctions permettent de sauvegarder le réseau dans un fichier\n",
|
||
"# ou de load le réseau stocké à partir d'un fichier\n",
|
||
"def save_net(file_path,net):\n",
|
||
" pkl_file = open(file_path, 'wb')\n",
|
||
" pickle.dump(net,pkl_file)\n",
|
||
" pkl_file.close()\n",
|
||
"def load_net(file_path): \n",
|
||
" pkl_file = open(file_path, 'rb')\n",
|
||
" net= pickle.load(pkl_file)\n",
|
||
" pkl_file.close()\n",
|
||
" return net"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Créer un poids de type DeepMatch comme valeur initiale de Conv1 (non obligatoire)\n",
|
||
"def ini():\n",
|
||
" kernel=torch.zeros([8,3,3,3])\n",
|
||
" array_0=np.array([[1,2,1],[0,0,0],[-1,-2,-1]],dtype='float32')\n",
|
||
" array_1=np.array([[2,1,0],[1,0,-1],[0,-1,-2]],dtype='float32')\n",
|
||
" array_2=np.array([[1,0,-1],[2,0,-2],[1,0,-1]],dtype='float32')\n",
|
||
" array_3=np.array([[0,-1,-2],[1,0,-1],[2,1,0]],dtype='float32')\n",
|
||
" array_4=np.array([[-1,-2,-1],[0,0,0],[1,2,1]],dtype='float32')\n",
|
||
" array_5=np.array([[-2,-1,0],[-1,0,1],[0,1,2]],dtype='float32')\n",
|
||
" array_6=np.array([[-1,0,1],[-2,0,2],[-1,0,1]],dtype='float32')\n",
|
||
" array_7=np.array([[0,1,2],[-1,0,1],[-2,-1,0]],dtype='float32')\n",
|
||
" for i in range(3):\n",
|
||
" kernel[0,i,:]=torch.from_numpy(array_0)\n",
|
||
" kernel[1,i,:]=torch.from_numpy(array_1)\n",
|
||
" kernel[2,i,:]=torch.from_numpy(array_2)\n",
|
||
" kernel[3,i,:]=torch.from_numpy(array_3)\n",
|
||
" kernel[4,i,:]=torch.from_numpy(array_4)\n",
|
||
" kernel[5,i,:]=torch.from_numpy(array_5)\n",
|
||
" kernel[6,i,:]=torch.from_numpy(array_6)\n",
|
||
" kernel[7,i,:]=torch.from_numpy(array_7)\n",
|
||
" return torch.nn.Parameter(kernel,requires_grad=True) \n",
|
||
"\n",
|
||
"# Calculer le poids initial de la couche convolutive add\n",
|
||
"# n, m signifie qu'il y a n * m sous-patches dans le patch d'entrée\n",
|
||
"# Par exemple, le patch d'entrée est 16 * 16, pour les patchs 4 * 4 de la première couche, n = 4, m = 4\n",
|
||
"# pour les patchs 8 * 8 de la deuxième couche, n = 2, m = 2\n",
|
||
"def kernel_add_ini(n,m):\n",
|
||
" input_canal=int(n*m)\n",
|
||
" output_canal=int(n/2)*int(m/2)\n",
|
||
" for i in range(int(n/2)):\n",
|
||
" for j in range(int(m/2)):\n",
|
||
" kernel_add=np.zeros([1,input_canal],dtype='float32')\n",
|
||
" kernel_add[0,i*2*m+j*2]=1\n",
|
||
" kernel_add[0,i*2*m+j*2+1]=1\n",
|
||
" kernel_add[0,(i*2+1)*m+j*2]=1\n",
|
||
" kernel_add[0,(i*2+1)*m+j*2+1]=1\n",
|
||
" if i==0 and j==0:\n",
|
||
" add=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
|
||
" else:\n",
|
||
" add_=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n",
|
||
" add=torch.cat((add,add_),0)\n",
|
||
" return torch.nn.Parameter(add,requires_grad=False) \n",
|
||
"\n",
|
||
"# Calculer le poids initial de la couche convolutive shift\n",
|
||
"# shift+add Peut réaliser l'étape de l'agrégation\n",
|
||
"# Voir ci-dessus pour les paramètres n et m. \n",
|
||
"# Pour des étapes plus détaillées, veuillez consulter mon rapport de stage\n",
|
||
"def kernel_shift_ini(n,m):\n",
|
||
" input_canal=int(n*m)\n",
|
||
" output_canal=int(n*m)\n",
|
||
" \n",
|
||
" kernel_shift=torch.zeros([output_canal,input_canal,3,3])\n",
|
||
" \n",
|
||
" array_0=np.array([[1,0,0],[0,0,0],[0,0,0]],dtype='float32')\n",
|
||
" array_1=np.array([[0,0,1],[0,0,0],[0,0,0]],dtype='float32')\n",
|
||
" array_2=np.array([[0,0,0],[0,0,0],[1,0,0]],dtype='float32')\n",
|
||
" array_3=np.array([[0,0,0],[0,0,0],[0,0,1]],dtype='float32')\n",
|
||
" \n",
|
||
" kernel_shift_0=torch.from_numpy(array_0)\n",
|
||
" kernel_shift_1=torch.from_numpy(array_1)\n",
|
||
" kernel_shift_2=torch.from_numpy(array_2)\n",
|
||
" kernel_shift_3=torch.from_numpy(array_3)\n",
|
||
" \n",
|
||
" \n",
|
||
" for i in range(n):\n",
|
||
" for j in range(m):\n",
|
||
" if i==0 and j==0:\n",
|
||
" kernel_shift[0,0,:]=kernel_shift_0\n",
|
||
" else:\n",
|
||
" if i%2==0 and j%2==0:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_0\n",
|
||
" if i%2==0 and j%2==1:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_1\n",
|
||
" if i%2==1 and j%2==0:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_2\n",
|
||
" if i%2==1 and j%2==1:\n",
|
||
" kernel_shift[i*m+j,i*m+j,:]=kernel_shift_3\n",
|
||
" \n",
|
||
" return torch.nn.Parameter(kernel_shift,requires_grad=False) \n",
|
||
"\n",
|
||
"# Trouvez le petit patch(4 * 4) dans la n ème ligne et la m ème colonne du patch d'entrée\n",
|
||
"# Ceci est utilisé pour calculer la convolution et obtenir la carte de corrélation\n",
|
||
"def get_patch(fragment,psize,n,m):\n",
|
||
" return fragment[:,:,n*psize:(n+1)*psize,m*psize:(m+1)*psize]\n",
|
||
"###################################################################################################################\n",
|
||
"class Net(nn.Module):\n",
|
||
" def __init__(self,frag_size,psize):\n",
|
||
" super(Net, self).__init__()\n",
|
||
" \n",
|
||
" h_fr=frag_size\n",
|
||
" w_fr=frag_size\n",
|
||
" \n",
|
||
" n=int(h_fr/psize) # n*m patches dans le patch d'entrée\n",
|
||
" m=int(w_fr/psize)\n",
|
||
" \n",
|
||
" self.conv1 = nn.Conv2d(3,8,kernel_size=3,stride=1,padding=1)\n",
|
||
" # Si vous souhaitez initialiser Conv1 avec les poids de DeepMatch, exécutez la ligne suivante\n",
|
||
" #self.conv1.weight=ini()\n",
|
||
" self.Relu = nn.ReLU(inplace=True)\n",
|
||
" self.maxpooling=nn.MaxPool2d(3,stride=2, padding=1)\n",
|
||
" \n",
|
||
" self.shift1=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
|
||
" self.shift1.weight=kernel_shift_ini(n,m)\n",
|
||
" self.add1 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
|
||
" self.add1.weight=kernel_add_ini(n,m)\n",
|
||
" \n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:# Si n=m=1,Notre réseau n'a plus besoin de plus de couches pour agréger les cartes de corrélation\n",
|
||
" self.shift2=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
|
||
" self.shift2.weight=kernel_shift_ini(n,m)\n",
|
||
" self.add2 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
|
||
" self.add2.weight=kernel_add_ini(n,m)\n",
|
||
" \n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:\n",
|
||
" self.shift3=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n",
|
||
" self.shift3.weight=kernel_shift_ini(n,m)\n",
|
||
" self.add3 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n",
|
||
" self.add3.weight=kernel_add_ini(n,m)\n",
|
||
" \n",
|
||
" def get_descripteur(self,img,using_cuda):\n",
|
||
" # Utilisez Conv1 pour calculer le descripteur,\n",
|
||
" descripteur_img=self.Relu(self.conv1(img))\n",
|
||
" b,c,h,w=descripteur_img.shape\n",
|
||
" couche_constante=0.5*torch.ones([1,1,h,w])\n",
|
||
" if using_cuda:\n",
|
||
" couche_constante=couche_constante.cuda()\n",
|
||
" # Ajouter une couche constante pour éviter la division par 0 lors de la normalisation\n",
|
||
" descripteur_img=torch.cat((descripteur_img,couche_constante),1)\n",
|
||
" # la normalisation\n",
|
||
" descripteur_img_norm=descripteur_img/torch.norm(descripteur_img,dim=1)\n",
|
||
" return descripteur_img_norm\n",
|
||
" \n",
|
||
" def forward(self,img,frag,using_cuda):\n",
|
||
" psize=4\n",
|
||
" # Utilisez Conv1 pour calculer le descripteur,\n",
|
||
" descripteur_input1=self.get_descripteur(img,using_cuda)\n",
|
||
" descripteur_input2=self.get_descripteur(frag,using_cuda)\n",
|
||
" \n",
|
||
" b,c,h,w=frag.shape\n",
|
||
" n=int(h/psize)\n",
|
||
" m=int(w/psize)\n",
|
||
" \n",
|
||
" #######################################\n",
|
||
" # Calculer la carte de corrélation par convolution pour les n*m patchs plus petit.\n",
|
||
" for i in range(n):\n",
|
||
" for j in range(m):\n",
|
||
" if i==0 and j==0:\n",
|
||
" map_corre=F.conv2d(descripteur_input1,get_patch(descripteur_input2,psize,i,j),padding=2)\n",
|
||
" else:\n",
|
||
" a=F.conv2d(descripteur_input1,get_patch(descripteur_input2,psize,i,j),padding=2)\n",
|
||
" map_corre=torch.cat((map_corre,a),1)\n",
|
||
" ########################################\n",
|
||
" # Étape de polymérisation\n",
|
||
" map_corre=self.maxpooling(map_corre)\n",
|
||
" map_corre=self.shift1(map_corre)\n",
|
||
" map_corre=self.add1(map_corre)\n",
|
||
" \n",
|
||
" #########################################\n",
|
||
" # Répétez l'étape d'agrégation jusqu'à obtenir le graphique de corrélation du patch d'entrée\n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:\n",
|
||
" map_corre=self.maxpooling(map_corre)\n",
|
||
" map_corre=self.shift2(map_corre)\n",
|
||
" map_corre=self.add2(map_corre)\n",
|
||
" \n",
|
||
" \n",
|
||
" n=int(n/2)\n",
|
||
" m=int(m/2)\n",
|
||
" if n>=2 and m>=2:\n",
|
||
" map_corre=self.maxpooling(map_corre)\n",
|
||
" map_corre=self.shift3(map_corre)\n",
|
||
" map_corre=self.add3(map_corre)\n",
|
||
" \n",
|
||
" \n",
|
||
" b,c,h,w=map_corre.shape\n",
|
||
" # Normalisation de la division par maximum\n",
|
||
" map_corre=map_corre/(map_corre.max())\n",
|
||
" # Normalisation SoftMax\n",
|
||
" #map_corre=(F.softmax(map_corre.reshape(1,1,h*w,1),dim=2)).reshape(b,c,h,w)\n",
|
||
" return map_corre"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def run_net(net,img,frag,frag_size,using_cuda):\n",
|
||
" h,w,c=frag.shape\n",
|
||
" n=int(h/frag_size)\n",
|
||
" m=int(w/frag_size)\n",
|
||
" frag_list=[]\n",
|
||
" #####################################\n",
|
||
" # Obtenez des patchs carrés des fragments et mettez-les dans la frag_list\n",
|
||
" for i in range(n):\n",
|
||
" for j in range(m):\n",
|
||
" frag_32=frag[i*frag_size:(i+1)*frag_size,j*frag_size:(j+1)*frag_size]\n",
|
||
" if test_fragment32_32(frag_32,0.6):\n",
|
||
" frag_list.append(frag_32)\n",
|
||
" img_tensor=img2tensor(img)\n",
|
||
" ######################################\n",
|
||
" if using_cuda:\n",
|
||
" img_tensor=img_tensor.cuda()\n",
|
||
" \n",
|
||
" coordonnee_list=[]\n",
|
||
" #######################################\n",
|
||
" # Utilisez le réseau pour calculer les positions de tous les patchs dans frag_list[]\n",
|
||
" # Mettez le résultat du calcul dans coordonnee_list[]\n",
|
||
" for i in range(len(frag_list)):\n",
|
||
" frag_tensor=img2tensor(frag_list[i])\n",
|
||
" if using_cuda:\n",
|
||
" frag_tensor=frag_tensor.cuda()\n",
|
||
" res=net.forward(img_tensor,frag_tensor,using_cuda)\n",
|
||
" if using_cuda:\n",
|
||
" res=res.cpu()\n",
|
||
" po_h,po_w=show_coordonnee(res)\n",
|
||
" coordonnee_list.append([po_h,po_w])\n",
|
||
" h_img,w_img,c=img.shape\n",
|
||
" position=[]\n",
|
||
" for i in range(len(coordonnee_list)):\n",
|
||
" x=int(round(h_img*coordonnee_list[i][0]))\n",
|
||
" y=int(round(w_img*coordonnee_list[i][1]))\n",
|
||
" position.append([x,y])\n",
|
||
" return position"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Fresque 0, fragment 2824/3000 (94.1%)\n",
|
||
"Temps par fragment: 0.759\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"if __name__=='__main__':\n",
|
||
" \n",
|
||
" # La taille du patch d'entrée est de 16*16\n",
|
||
" frag_size=16\n",
|
||
" # La taille du plus petit patch dans réseau est de 4 *4 fixée\n",
|
||
" psize=4\n",
|
||
" using_cuda=True\n",
|
||
" \n",
|
||
" # Variable des données\n",
|
||
" base_dir = './training_data_small/'\n",
|
||
" fresque_filename = base_dir+'fresque_small{}.ppm'\n",
|
||
" fresque_filename_wild = base_dir+'fresque_small*.ppm'\n",
|
||
" fragment_filename = base_dir+'fragments/fresque{}/frag_dev_{:05}.ppm'\n",
|
||
" fragments_filename_wild = base_dir+'fragments/fresque{}/frag_dev_*.ppm'\n",
|
||
" vt_filename = base_dir+'fragments/fresque{}/vt/frag_dev_{:05}_vt.txt'\n",
|
||
" fragment_id_offset = 200\n",
|
||
" \n",
|
||
" \n",
|
||
" net=Net(frag_size,psize)\n",
|
||
" \n",
|
||
" # Pour chaque fresque, le nombre d'itérations est de 1000\n",
|
||
" itera=1000\n",
|
||
" \n",
|
||
" if using_cuda:\n",
|
||
" net=net.cuda()\n",
|
||
" \n",
|
||
" # Choisissez l'optimiseur et la fonction de coût\n",
|
||
" optimizer = torch.optim.Adam(net.parameters())\n",
|
||
" loss_func = torch.nn.MSELoss()\n",
|
||
" \n",
|
||
" # Dans le processus d'apprentissage du réseau,le changement d'erreur est placé dans loss_value=[] \n",
|
||
" # et le changement de Conv1 poids est placé dans para_value[]\n",
|
||
" loss_value=[]\n",
|
||
" w_values=[]\n",
|
||
" ####################################################training_net\n",
|
||
" \n",
|
||
" # Detection des fresques\n",
|
||
" fresques_paths = glob(fresque_filename_wild) \n",
|
||
" N_fresque = len(fresques_paths)\n",
|
||
" \n",
|
||
" time_old = time()\n",
|
||
" \n",
|
||
" # Iteration sur les fresques trouvées\n",
|
||
" for fresque_id,fresque_path in enumerate(fresques_paths):\n",
|
||
" # Charge la fresque\n",
|
||
" fresque=cv2.imread(fresque_path)\n",
|
||
" h,w,c=fresque.shape\n",
|
||
" fresque_tensor=img2tensor(fresque)\n",
|
||
" \n",
|
||
" # Si GPU, conversion de la fresque\n",
|
||
" if using_cuda:\n",
|
||
" fresque_tensor=fresque_tensor.cuda()\n",
|
||
" \n",
|
||
" # Detection des fragments d'entrainement\n",
|
||
" fragments_paths = glob(fragments_filename_wild.format(fresque_id))\n",
|
||
" N_fragments = len(fragments_paths)\n",
|
||
" for fragment_id,fragment_path in enumerate(fragments_paths):\n",
|
||
" clear_output(wait=True)\n",
|
||
" print(\"Fresque {}, fragment {}/{} ({:.3}%)\".format(fresque_id,fragment_id,N_fragments,(fragment_id/N_fragments)*100))\n",
|
||
" print(\"Temps par fragment: {:.3}\".format(time()-time_old))\n",
|
||
" time_old = time()\n",
|
||
" # Tous les 100 cycles, enregistrez le changement de poids\n",
|
||
" if fragment_id%50==0:\n",
|
||
" w_values.append(net.conv1.weight.data.cpu().numpy())\n",
|
||
" \n",
|
||
" # Chargement du fragment et de la vt\n",
|
||
" frag,vt=load_training_fragment(fragment_path,vt_filename.format(fresque_id,fragment_id+fragment_id_offset))\n",
|
||
" \n",
|
||
" # si GPU, conversion des objects\n",
|
||
" frag_tensor=img2tensor(frag)\n",
|
||
" if using_cuda:\n",
|
||
" vt=vt.cuda()\n",
|
||
" frag_tensor=frag_tensor.cuda()\n",
|
||
" \n",
|
||
" frag_pred=net.forward(fresque_tensor,frag_tensor,using_cuda)\n",
|
||
" b,c,h,w=vt.shape\n",
|
||
" # Utilisez la fonction de coût pour calculer l'erreur\n",
|
||
" err_=loss_func(vt,frag_pred)\n",
|
||
" # Utilisez l'optimiseur pour ajuster le poids de Conv1\n",
|
||
" optimizer.zero_grad()\n",
|
||
" err_.backward(retain_graph=True)\n",
|
||
" optimizer.step()\n",
|
||
" \n",
|
||
" loss_value.append(err_.tolist())\n",
|
||
" \n",
|
||
" del frag_tensor,frag_pred,err_,vt\n",
|
||
" torch.cuda.empty_cache()\n",
|
||
" \n",
|
||
" # Sauvegarder le réseau\n",
|
||
" save_dir = './trained_net/'\n",
|
||
" extension = 'from-random_full-dataset-small'\n",
|
||
" net_filename = save_dir + \"net_trainned_{}_{}\".format(extension,datetime.now().strftime(\"%m-%d_%H-%M\"))\n",
|
||
" save_net(net_filename,net)\n",
|
||
" \n",
|
||
" # Sauvegarder les poids\n",
|
||
" poids_filename = save_dir + \"save_weights_{}_{}\".format(extension,datetime.now().strftime(\"%m-%d_%H-%M\"))\n",
|
||
" with open(poids_filename,'wb') as f:\n",
|
||
" pickle.dump(w_values,f)\n",
|
||
" \n",
|
||
" print(\"Net sauvegardés dans {}\".format(net_filename))\n",
|
||
" print(\"Poids sauvegardés dans {}\".format(poids_filename))\n",
|
||
" "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Poids pickled.\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"with open('./save_weights','wb') as f:\n",
|
||
" pickle.dump(w_values,f)\n",
|
||
"print(\"Poids pickled.\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"(3, 3, 3)\n",
|
||
"(8, 3, 3, 3)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"w = w_values[0]\n",
|
||
"print(w.shape)\n",
|
||
"print(net.conv1.weight.data.cpu().numpy().shape)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[<matplotlib.lines.Line2D at 0x7f9cc2acab50>]"
|
||
]
|
||
},
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3dd5wTdfoH8M+zjaXXpZelwwrSlmIBUZSuWM92FjyP405PTz1+4CnoKXJYDhsox6l4eipy6iknVaQjZZeyCyy7bAWWBXZhK23r9/dHJtlJMkkmyUwmkzzv14sXm8lk5skkeeY73/kWEkKAMcaY+UUYHQBjjDFtcEJnjLEQwQmdMcZCBCd0xhgLEZzQGWMsREQZteM2bdqI+Ph4o3bPGGOmtG/fvnNCiDil5wxL6PHx8UhOTjZq94wxZkpEdNzVc1zlwhhjIYITOmOMhQhO6IwxFiI4oTPGWIjghM4YYyGCEzpjjIUITuiMMRYiTJfQM85UYNGGDJy7UGl0KIwxFlRMl9CzCi/gvU1ZKL5YZXQojDEWVEyX0Iks/9fxxByMMWbHdAk9QkronM8ZY8ye6RI6YMnoXEJnjDF7pkvoXEIPTgWll7E5vdDoMBgLa6ZL6CRVonNCDy6T3t2O6Z8mGR0GY2HNdAndVkIHZ/RgUna52ugQGAt7pkvo9a1cjI2DMcaCjQkTurXKhTM6Y4zJmS+hS/9zCZ0xxuyZLqFHWOtcuA6dMcbsmC6hcx06Y4wpM11Cj+Bmi4wxpsh0Cb2+Dp0zOmOMyZkvoXMJnTHGFJkwoVv+52aLjDFmz3QJ3VaHbnAcjDEWbEyX0Hk8dMYYU2a6hM6jLTLGmDLTJXQeD50xxpSZLqHXj7bIGGNMznQJnQfnYowxZaZL6FyHzhhjykyX0MlWh25wIIwxFmRUJXQimkhEGUSURURzFJ5vTkT/I6IUIjpCRNO1D9W6L8v/XOXCGGP2PCZ0IooEsATAJAAJAO4nogSH1Z4AkCaEGARgLIC/E1GMxrFK8Vj+5xI6Y4zZU1NCHwEgSwiRI4SoArACwDSHdQSApmS5Y9kEQDGAGk0jlfB46IwxpkxNQu8E4KTscb60TG4xgP4ACgAcAvC0EKLOcUNENIOIkokouaioyKeAuYTOGGPK1CR0UljmmE4nADgIoCOAwQAWE1EzpxcJsUwIkSiESIyLi/M6WIDHQ2eMMVfUJPR8AF1kjzvDUhKXmw7gO2GRBSAXQD9tQrTH46EzxpgyNQk9CUBvIuou3ei8D8Aqh3VOABgHAETUDkBfADlaBmpFPNoiY4wpivK0ghCihoieBLAeQCSAT4QQR4hopvT8UgCvAviUiA7BUoieLYQ4p0fA3GyRMcaUeUzoACCEWANgjcOypbK/CwCM1zY0ZVyHzhhjykzYU9SiqKLS0DgYYyzYmC6hZxZeAAC8tuaowZEwxlhwMV1Cr6l1at7OGGMMJkzojDHGlJkuobdp2gAAEBttutDDwrJt2UaHwFjYMl1WjG/dGAAwe6Iu/ZaYnxasSTc6BMbClukSurUdelSE0ogEjDEWvkyX0K24GTpjjNkzXULncjljjCkzXUJnjDGmzLQJnbv+M8aYPdMldCKudGGMMSWmS+hWPNoiY4zZM11C5/J58KusqTU6BBaG9uScx9MrDoR1Yc90Cd0qfD+y4Fd2udroEFgYeviTvfjhYAEqa8J3vCfTJXSuQmeMMWWmS+gs+L2zMdPoEBgLS6ZN6GFcTRb0vtxzwugQGAtLpkvoxLdFGWNMkekSuhUX0BljzJ75EnoYFtCPna3gpoCMMY/Ml9Al4dLW9PyFSox/exue/+6Q0aEwZgphkhoUmS6hh1uzxYuVlpJ5Ul6xwZEwFtzCLTcoMV1CDzdCultQWF5pcCSMsWBnuoQebifhgydLASCse78xxtQxXUIPN7V1YVwhyBjzimkTerjc+AiX98kY85/pEjqPh84YY8pMl9CtBHctYowpCOfcYLqEHkrl89JLVcgqrDA6DMZCAg8LYsKEbhUKdcu3Ld6Jmxdtc7uOWWuYSi9VGR0CY2HHdAndrAlOyYniS0aHoJvX16UbHQJjYcd0Cd0qBAroqpj1SiS78KLRITAWdkyX0LWqJ8sqrOA23oyFkHC+GWpluoSuhazCCty8aBve3XjM6FA84q8oY94J55ujqhI6EU0kogwiyiKiOS7WGUtEB4noCBFt1TZMZ/5URZyVxkVJPl6iUTSMsWARziX1KE8rEFEkgCUAbgGQDyCJiFYJIdJk67QA8AGAiUKIE0TUVq+APd0UfeLL/ai4UoPPHhvhcRt1JqigDt+yBmPeCeeSuZXHhA5gBIAsIUQOABDRCgDTAKTJ1nkAwHdCiBMAIIQo1DpQR67OwqtTT3t8bYSU0bkKXT8l3GyRsYBTU+XSCcBJ2eN8aZlcHwAtiWgLEe0jooeVNkREM4gomYiSi4qKfItY8sa6DFXrpeaX4mJljd0ya0IPl0kyjJBZeMHoEBjT3Y7Mc4ifsxqFFVeMDgWAuoSudB3jmAmjAAwDMAXABABziaiP04uEWCaESBRCJMbFxXkdrLeKL1bhtsU7ce3CTbZlv/rHLqxOLQDAIxkyxvyzfGcuACD1ZJnBkVioqXLJB9BF9rgzgAKFdc4JIS4CuEhE2wAMAqB5MxJvOhZ9s89yYVF2udq2bG9uMfbmWmb/MUM+N0GIjLEgoaaEngSgNxF1J6IYAPcBWOWwzg8ARhNRFBE1AjASwFFtQ/VOWkE5Fqxx31tRqcqlprYOO7PO6RJT3jnvO9vwbR7mj4935OLrpBNGh8ECxGNCF0LUAHgSwHpYkvRKIcQRIppJRDOldY4CWAcgFcBeAB8JIQ7rEbDaO9mT39tu91gIgUUb7OvdlUrob288hgc/2mMrxWtlU/pZjH1rC35Mdby4YUw/r/6YhtnfhtcE44G8NRZsV9BqqlwghFgDYI3DsqUOj98E8KZ2oWlr7eEzeG9Tlt0yebPFXdnnkXGm3NZl/fwF+zk8hRD4cGs27hveFa0ax3i9/6OnLaMqHikox9SrO3r9esaYe0aO8xQsY0ypSuih4A9f7HdaVieA/JJL+OFgAd5cbym9T7yqveLrk/JK8Ma6DBw4UYp/Ppyoa6yMMeYL0yV0rc+EY97YrOrmaHWtZZJmxyaQnhw7W4FZ36Tiup6tfQmPMcZUC8uxXORcJfPL1bVebaeyphbrDp9xWv762nSknCzF7pzzAMw7eiJjLPiZLqHrXVV1uMDSnvTZlSnYf6IEP6WdRY1UOnfnN58mY+a/99napTqyjh/zr1/yAAA/Hz2rTcCMMSYxXUJ354KX1SFK8ksu2/6+84Nf8NvPkrFkczYOnix1+7odUlPHlcn5is+fKrVs93J1LfLOXcRv/pXsd6zM4vCpMiTMWxc0vfUYM4oJ69Dry+iVNbVoEBVpezzhbffTufnqbY2H2b1Y5f+Jh9X7eEcuLlXVYkfmOdw5tLPR4TCDBbJWM9iGDzF1CX3plhwk59W3F7eWgoPJ/hOeh+hdsjnL5RcjuL4ujAWH6to6ZJ61n2DdyJaDwdJs0XQJXX7c3t54DHcv3YVd2ecNi8cdIQRKLlV7XO/N9Rl2VT2MMffm/5iGW97eFpSFOCOZLqErOX7et/krfblc8vSSo6fLkSt18V+2LceXsBhjHiTlWa58Sy7yMM1yIZHQ53znW9fm4+cvaRyJxY1vbQEAbM/0f0yYILmSYyzocfWkCRO6q7oqX1q4eNvWHAB25ZxHndR4/fF/JeOprw54vQ1Ppfyb3tqCN9a5H1iM1QvEjal1h89g8rvbbZ89C17hXAgyXUJ3pfyy57pqrfySfR5f7DmOjUfPYlVKAbKLvJvMYer7O9w+n3PuIj7Yku1PiGFJzxtTz3x9EGmny3GlxvtCAGOBYupmi3LnHAbT0tPSrdm2ducAUKrBdGuuCplcHmTMmbuTdzj/ZkKmhP7d/lMB29cOleOlh/Ps46EmnD7L7KILKDBh6xEjqlqC7VsRMgndSGUaVPfU6lQPXHGlOug6P5hZOMwsP+7vW+2mbWSeBcv3ghO6Bh771P9u/O9vytQgEnsniy9h4Msb8Nmu45pvO9zwOZGZASf0IHHModebFvKk9vk/pYX2QGCByLXWyVCCpUcgY0pCJqGb/YfmqgToz9sKtxZ2el72VtdaDuaunODslcwYEEIJffnOPKNDCDqPfLLX6BACKhA3LosqAteaijFvhUxCZ+HL5BdnjGmGE3qQqKrxPImGr8xeHRVUwqwaK9gpVVUGslVXsN0s54QeJDILvett6o2aWmO+dWZuLqlmlipmHKVCiqtOhwERJIUmTug60jOfCSFUl+pPFOszCJkneg1+5kjrw3zgRAl6vbAW2zOLNN4yY/rihK4jPRPah1uz0efFtZoMO6AXvTpLuaJVK5e9uZZJU7QYLZPpI+gu/oIkHk7oOtJz8P1v91nmLg3kGDaMseDGCV0nl6u8H5VPaWjWIDnx++TwqTKjQ9CcXk0jn/xyP77ae0KXbYeaT3fm4khBucvnc4ouojrQ90C4Dj209Z+3zuvXfLAly+7xnpzzyCzUvgdpoKS5+dGZQSBv6v6YehrP+zhRS7j5fLf7oSymLdmJ11YfDVA09mpq63DSoHtWACf0oHLgRKnd43uX7cY/ttZPYxc/Z7XTa4KuLlEmiENzi5t5BqdiL6abs94H0Zvjd/xva9Mx+o3NOFN2JSD7d8QJPYh4k0gC1UTrclWtx3r6LB2bXOph2bZs7Dte4vL5YD5JhqvtmUUY+upP2Jh2FtlF9XMIy6vAjDwPW/e9Uxpa25uTj5Y4oQeRjUcLFUvhRvrVP3Yhcf5Gl8+fLruMmxdtDWBEzrxNwAvWpOOuD3/xuJ6h7ZqZnf3HLVev7uYiMOI8fEWaxvKKD9NZ6oETOnPrkIcbm0aVRJRonX+V6tC59M7krFU7H23PNTgSC9NNQcfs+ZJf6uoErtTUolGMfx//7Ut24uDJUs8rBohWyZanNws+1qqVYL1oqpZaqKWfMbYRA5fQTcqb77UQAos2ZKCwwnKj5pUf05Awbz0q/Zzw2FMyD1QrEccfeWVNLcou6TNpeDBdkYSjYJkZyJPLBlXBcEIPMfuOF6Ow3P4Oe0HZFby3KQvPrUwBUN8pqVLHAcGMNH15Ega9ssHv7Sidj6zHzuy2HStC/JzVSD9jrqalwVBCv1xV67EwtHCtMc0mOaGHmLs+3IUJ72xTfM4xgR8tKMc9S3/xqRNUMPsl279JKMxSCvTHuiNnAADJea5b+wQT68k1kJ9MyslSxM9ZjcOnylBZU4t3N2aisqYW/eetwy2LHH5jDmf/JIOOq6qETkQTiSiDiLKIaI6b9YYTUS0R3a1diMwdpVJkiavqBod1//q/NCTlleDASc9fvv0ngveHr3XNTiAmyvBW2aVqfLQ9x9QjWGrBXQld62S/Ic1y0tuSUYhPd+bh7Y3H8PEOy81PpwHvguHSASoSOhFFAlgCYBKABAD3E1GCi/VeB7Be6yCZM7++Pz68tuJKjR87dG/p1mzEz1mNC5X+7UPr35Ta7VXX1uk+ps7z/03F/NVHNeswY7bzgjVco5qSXrY1T1SupgyOdK6uhD4CQJYQIkcIUQVgBYBpCuv9EcC3AAo1jI95YUWQjQWiNml8vsvSlbskADccvaleUoxf4Zf77MoUJM7fqDgWj1bKLluuuqo1Hts+SAqWnkkfRrZDJzZ9h6j27/VbMgpxqUq/gpASNQm9E4CTssf50jIbIuoE4A4AS91tiIhmEFEyESUXFfFY01qb42EsEG+rEuRjUoTKpf5TKw54XMfbOvTVqQUA9G3SGCKH328/p9uXFy8F4P6PmqsCpY/n0eVJSJi3PqBjGqlJ6ErvxjH+dwDMFkK4PbpCiGVCiEQhRGJcXJzaGJkbrpK0mpKopxxRo2OJ0yhJedqP8RHIo2SaErXGXB3jQJeAfZF2OnAJXU3PknwAXWSPOwMocFgnEcAK6UzWBsBkIqoRQnyvSZTMybGz7sdPURrtMSmvBM+uPGg7Qx+Vvmg/pp4GAPx793Ek5ZUg6YWbNY3VE2vp3+zJylzhh8bJWs/vjFZH6PNdebh7WGeNtuaemoSeBKA3EXUHcArAfQAekK8ghOhu/ZuIPgXwIydzfdTWCVRc8b3TzHf7T6FprP3H/uWeE/hyj/v6d1++3Gpf4+8Nr0Cmpl5xTQK4t3p6NdszexNNV/EH4jtxQjYjmbujmJJfhqOny9G/QzPdY/JY5SKEqAHwJCytV44CWCmEOEJEM4lopt4BMnsvrzqCwa/8ZHQYqny8IxdLNmchfs5qVb1SA5Fa/K2LvrZna823qYatas3c+ddnro7x9E+T6h9oPpaP++e3HFPf/mPSu9v9jEYdVYN5CCHWAFjjsEzxBqgQ4lH/w2Ku/C/VvrbLl05BgcwJy7ZZxnO/XFWLBlGRAdxzaPHlpHG6zPUUiKF+k1XL77j8wnHd4dMabll73FPU5O74wPMwsL6y+1E4JAAthwu1VSeYoPRpdB70porktIpJFsxwzIHg6ezl6d6V0TihM0W/ZJ3TpDORq+aOSXnFtnFRbCPp+Viu0isnBUcKYYD7Kwqj5q4tLK/vTFZbJ5CsQwsqb/HwuSEg99xFdG/TWLPtlV6qwgMf7UGnFg0126Zjsr5n6S4AwF3DOgf08l9Ne/rdOf6NBaMHn25Ku3lRKFW5TH1/BzY+Oybg+z1ZUn9T9NCpMtwtfaeNxCX0EODt/IWeWpNUSYN4nSqtr4PV85K3vpULcKGyxrZ/b1/vrQuVNbjzg512U+il5pc6dV4JCjpVS5mkxsXjZ3zuQuCHNf7hoGPrbeNxQg8B/o5rrkZVTZ1PN2BtP0QVmYMADHhpPR76eI/i8/kll7zqsbp4Uybi56xGrayDlPxktjWjCPtPlGLRTxkALJfuty3e6XabtQZ1tqqvlmJKvL3iOH7+InKKLCfynVnn8GOqc3LWuhATiGnqOKEzVWb+e79iZyWPZCXLpLxi2yQbdqs4/G725BZj+U77Kb2Ona3A9a9vxj+356je9Xs/ZwEAauqUS/x7c+2rVtaqaMGwy8+hef3lTVt9d1c6OecsyexwgTH1z1rzNvne8OYW3PR3y1y4D360B09+aT8kxOFTZbYWWlq11f/vgVO2MXn0wgk9DGnxpVJbIpKvds/SXRjx2s+uryhkv5vlO/PsnrJ24tiT49+NJ3kJ/1/SoGBW6ac9Tx9W5+aNB9tYLh+5OflZx+tepWO1wdnyK3jwo92qZo9KOVmKie9sc9mVP9B1/lPf36H5PpPzSjDorxuwXhqLXg+c0EPAHo2GVLXR4bpevsnv9p9yeFZ9Kxdf6rfV/jDVbNtM9xJLLhk7Xd6HW7KxM+s8vt3veZan19YcRfqZCqTm+3jFoOMHQ6TNCeWgNO+AnjfdOaGbjFLK+3BLtqb7KLnouUSl9srfWiL+XlYSrKmtc1jH8v86WcnFegldfLEK/0k+CVeOFJThfynKpcwqaT8b0s6qC9aBrz/idzda6u71GDjKm5uivlT3azXeutc8xOptlYraAbH2BLBFk7W6TM+rDU7oJhOIERDVTNag9ktpnQ197veHbctcvQf5OlZ//Go/Zn2TirzzFxVfM++HIx5jeGfjMbfPC+Fbj1tX3pb2J68eqq0TdsMRe8vWEkjl+mkF5R4n8Qac6+QzzjpXOx0/fxHxc1ZjzSHfeklad1FTW+exzbjL9+cx4fvm3mW7Va337s+ZPu4hsDihm4yeMwdZafrlVfilWafxcrOKzVmp88a/dx93s5aFyxuGKn7tWxXG5aiudX1TMf1MOZ79+qDqVi/v/pyJ0W9sxnEXJyZPvB2RcvJ76sYOUdNq6Ig0nvfHO3KRccbzfQZXFv10DFPf34G0gnKnFh9KJfDN6YVOE567oraAcbmqFgvWeDeBs7v7JsGGEzpzouVlt9IPNb+kvn37mDc2KyYV6yLrc3nnfS/dqjHz3/udln2++zjq6oRi6f2JL/bjuwOnkHvOdVdw+XvflX0OQP0JylvBkFL2HS9xOQG5Goek0vmrP6ah39x1Hmeomv5pEkYs+BlLt2bjgIqrDTU+3pFja72i1oebtanStPYX8bbfiDc4oYcIfy7nfaF+vk33qehE8SWvk1VqfmnAZlB65cc0u+aathONjvuMn7Mar69Ld/Gsfi3Rl23LxhcqroSsyi5XI/2M67pqx8/I+nCXVG9dpHIe1oVr0z0WMkouVam6eq3yYQq/Cj/nurWyzpm7jlu5ME9Gv7E5oPvTMp+WqmjWJnfb4p1Y6eZGqSN/QnV1Q7am1nOttj/HSOsb3WosWJOOdIUqFVfv8FdLd2HiO56rdqyvV3Nj8/uDBT6drP/4lfupBXOKLrgdfTLQ9CqQ8FguzCtJecUYHt9K9/1YSzOuvvZf7j2Je4d3tUs2rn4kuecuoqqmDjFRyuUXb35bQ7q2AGC5sgCAvHOu68QrvRzCwB0tfv+1dQJny68gKqL+qPmzWaUbqL7Ynllkaxf/1d4TuLZna9w6qKPf2xVC2O6rWDsRBYuMsxXo1177CS+4hM68Yh1US+9hV22ldhcZJ0WqU1WbkDILLcmnXOGy3JuSW9umsXaPi93UA//hC+d6eW/ctniH7W/5eDe+em7lQVy7cBNGLPjZr7gAqGpB48jxpESwHPuHPt5rt7xUo96UWlWV6EGvGkNO6Mwngbrx781ufJ3CLsXXziw+EEJgw5EzblvQWNl1srG2cvFj39+76RXqrrWO0mG9fYn7MW8A589O6TsjHxjNtj+PWzY/vQpEnNCZqTn+LjZn6DtSotMP0csf5rbMIsz4fB/e3aiuaWjKyVK7Jn7yk1bJxSpkqqz2cFlnKy2evzrNzav9yz62DjUOKd4sk2uYCdehs6Dm7c2j6cuTFJf7OsDSRS86HG09VojHPk12u461imbx5izUCoHZE/u5XX/akp0u65MnvrsNZ8srkbdwisfYvvAwCfiGI771pnXns10q+g4ofC5E2o4guu6wfq1KfKXX5NxcQmc+CdSUYJ6aPcqdq/Ctjbc/5D/LlUmuxyxROi+pbclyKL9U8Wh706b9m32ex1NxtDm9EPFzVqNIYYRMX3hzbn7xv869hr2VlFuMkQs2Yua/9/m9La1xlQsLS/JJNjx55Ud31Qba8Pd36G/LF1/37+lKRynBfLYrDwBw+JS6cVE8xqByv0B9W3V/vLk+w+eOXHrTq7aJEzrzWsrJUiTMW290GF7Rq0Qkr9NWsw+lhP7DQcfRJ+3VCmG7Qerr+/BlCCCtkmGSda5Npxhcvxl5b2JfeVMYAICsQm2aYRqJEzrzWqAuYTeo6FGXfLxE1bbOXajEJw5jyPgiu+iCqsHLHLlKxIUVV/D0ioNuX3uyuD4x+Vr36ksVWZ2X48e4svGopX5+r8Ikykqb1qp+2dtxj25etA0L17rqoastvQoYfFOUeS0QA4QBwIzPtTtxOLZ19tU/t+fiS9kNRn9/lzU+dEX3ZM63qU7LIjxkEKWnrb1GfW0OqorCpvUYdlitpVuzMWeS+xvV2uCboowFBVctX9x1MrKt42Ey4+2ZRW6fTz3l3KHHcUagFUnOwxW4Sh8VlTVYtCEDLmbpU2VVSgEKSi/jgy1ZinX1V6qVN74nV7mefP5q70ZDNCO+KcqCRrA0H46fs9roEOx+mGpmjvJ0s8/TlcQLCq0/jhfXDz+Qmq/cg9Nd56n3NmW5rW++7KHE/NRXB3Dtwk14Y12GV6NiKr2XYKDl2PiBZsqE3qttE6NDCG/BktFDwLULN/m9DfmUZt6O9a2GYw/T79xMKVcnBK5U1+Kt9Rl2y10Nr3DghDbD4mrpuf+4v6ehBU9VYD5vV5et6qxXHCd0IwWqDt0Mnl2ZYsh+5T1EF6ypv5GnV4cVOXfvOavwAvrNXYfFm7Psll/zN+UT15sOiT8YpJzUfygIbrYow12GWbgb7zDRhLXuOsLgX/TvNLyRbZRAjbWvB1MmdBMfbxamtPzOll+pdtreKmmi7ECU0EPdGZXT3vmDb4oyxgAAf1vj3FY6q/ACautEwIZkCGUBmIcdmWddT13oD07ojAWAlpNdVFY7t8JIyitGv7lrsTPL/y7zTH9f7nU/WJqvuGMRYzqrrq2zTZCsCYXL9d052k3szfS3KV2fYZ45oTOmo4R563DJxO2ambmYssqlRaNoo0NgTBU9knlagTajH7LQoyqhE9FEIsogoiwimqPw/INElCr9+4WIBmkfar25UxPw0q0JePe+wXruhrGgZB1jhTFHHhM6EUUCWAJgEoAEAPcTUYLDarkAbhBCXA3gVQDLtA5UrnGDKEy/rjtu02BmcMYYCxVqSugjAGQJIXKEEFUAVgCYJl9BCPGLEMI6juluAJ21DZMxxkJHb52GL1GT0DsBkA/fli8tc+U3ANYqPUFEM4gomYiSi4rcjyqnhq7DejLGmE6iI/W5falmq0pZU7HpPRHdCEtCn630vBBimRAiUQiRGBcXpz5KN754fCRWzBilybYYYywQ0k7rc2NbTbPFfABdZI87AyhwXImIrgbwEYBJQoiA9W64rlebQO2KMcaCmpoSehKA3kTUnYhiANwHYJV8BSLqCuA7AA8JIY5pH6Zne/4yzojdMsZY0PBYQhdC1BDRkwDWA4gE8IkQ4ggRzZSeXwpgHoDWAD6Q6rVrhBCJ+oXtrF2z2EDujjHGgo6qnqJCiDUA1jgsWyr7+3EAj2sbGmOMMW+YsqcoY4wxZyGV0Id1a2l0CIwxZpiQSuiPX9/d6BAYY8wwIZXQr+vNTRgZY+ErpBJ6bFSk0SEwxphhQiqhx0SF1NthjDGvcAZkjLEQwQmdMcZCBCd0xhgLEZzQGWMsRHBCZ4yxEMEJnTHGQkTIJfT+HZoZHQJjjLk1qkcrXbYbcgn9t6O5+z9jLLiN6N5al+2GXEJn/hnRXZ+SA2NMfyGX0Mdf1d7oEHR1YO4tWD59uG7bf/aWPsh6bZJu22eMARCK0zL7LeQSepMGUdj7QuhOR9eycQx6t22i2/aFAKJ0mpGcMaavkPzltm1qPx3dU+N6Y+3Tow2KRr3re7XBtlk3ekiwo0IAABIHSURBVFyvdeMGdo+VxrCJjVb+aB+7zv09BgFLyaFdM/t9dGvdCF/+diQiyGN4jDFPSJ8fkqop6Mzo+yeuQ3QkoUebJoiNjkD6mQqjQ1Kla+tGissfuaYbhkoTeDSMqR9V8u5hnfH7sT3x8Y5crD10GlGRESiqqMTO2Tdh2PyNtvWeu6UPru/dBkO6tsS+48VIyS9T3E/zhtEAgDrZFWHrxjHYKp1odj0/Drcs2oryKzV+vU/GwppOVS4hm9AHd2lh99iXE+KMMT2wbFuORhF5p1OLhjhVetn2uF+HZpg2uJPTejf0iUPPuCZYcMdALLhjIC5X1SK/5BJaN2mAb2Zeg/mrj+I/M69BtKwa5bPHRuK9TZn4eEeubVnKvPFIPl6Mqzo2B2C5qpn7/WEsf3Q4buzX1rZeu2axSH15Aq5U1+K5lSlYfei0Hm+fMVWaxUaZsnChTzoP0SoXrbRoFO3za/MWTvH6NVd1srShz14wGTtm21e9NIqxH+t98kDLzV/HE1XDmEj0btcUAJAY30q6UrH/mJs3isbcqQn4esYou2Xj+rezPX5oVDfkLZxil8zlYqMj0cBFtQ5jgXBTv7ZIfXmC02/DDK7tqc9kPGHziyQ4F9HH9WuL9X8ag6du6oW/3naV0/OOdfFy06+Lt/0do8FNxOWPDses8X0BAJERBCJC2isTbM/fenVHu/VbNooBADSO8f0ia2QP/9rCavG+GfPG3KkJeOfewQBg+0W/Mm2AcQG58Mebetk9vn9EV7vHvXRq2BA2v0ilKhciQt/2TfHs+L545Np4vDKtPqmP7t0GoxWmtLNW5chvRM6d2t9uHV8mqx7QqblT65JGMVFo3jAaL07pjwiHu5EvTOmPV28fgLF947zel1aeuaUPAKBPO/1a3TBtdGrRUPd93JLQzvNKPhjTJ852kz+xW0s0aWBfiOneprFf29ej70XnlvbHO6FjM6z83TW2xzrdEw2fhC7XI075C/DwNfG2v3u1bYJ2zWKRt3AKvvvDtQCAJQ8Mxde/G4XfjemBp27qbVtXXlXx7e+vwRePjwQA9GvfVHE/6a9OVB1rykvj8fjoHk7LG8VE4aFR3UB6fTNUaNcsFh8+OBRf/naU55WZoYb6UMjwlmOi1cryR4fjuVssV68dWjhfNXdp5d/JaskDQ/16vZLBXeyPN8H+xNEs1vfqXHfCJqFb017PuMaYOyUBgH1rEUfyS6KhXVsi9eXxmHJ1BzSIisTzk/ujsezL27FFQxybPwnb/+9GDOvWCrHRlu0qXQq+OKU/YqMj8c69g3Hn0E5o6Uc9vRYW3DEQgxxuIHtj0sAOaNOkATY8M8a27NVpztVXeurbTvnEqYXX7hiA1+6wfI5tmzbwsLb3oiII/3vyertlI7u3clkY8Fbfdk3xn5nX4M27r9Zke+5E+tmmtWms8gkhMoLw+OjuSH91omI1qGMzXnfGe7iKcCztH3p5PNo0iXH7mn88NMxpWd/2TfHAyK4Ka1voNV1m+CR02Xfthj5xeHpcb8XEs/v5cbhnWGfcP9z+w1A6o/7joWH4250DAVg+oC6t7JscjujeCodeHo/5tw+wNQe0lrZvH9IJi3412O8fgb8eGNkVPzxxnd/bsV7SN4yOxEOyK51AcNXU019/v2cQHhzZDdMGd0JMZAQW3jUQeQun4KqO6gaAU5OUbxvcEQM7N7c97t6mMT7/zUis+9MYN69Sb9aEvhgeX1/IACxVGI7aNHGdFG+S3Rj/8EFLaTY2OgI//vF6PDWuN47Nn4SHRnXDXyb3x7o/+d7fI8LN1SYR2d5DJ6k6w3rVYf0JzZrQ1/ZdbtOkAfb8ZRyevNG+LvvDX9cnX+t9MPlv8KdnxuDd+wbbHjeNjUaH5pb9LbhjIH4/tqdTbBOuaq94P2n+tAG4c6h9y7QpAztg8QNDXL5Pf4VNQu/cshG6tW6El269ChERhGdu6YMWjZzPvO2bx+LNewY51VkrmXBVe6ebHY6axkbj16O6Ye8L4xSrWm4dZLnZ2biB+e7Uy1mbYRlxfnrrnkG6bHfywA4ALFUJx16bhJv6WUp3rvJON4cTy6fTR7jc9j8eGoYb+8bhOelGuNXmP49VVXrb/OexeGFyf4/rRUU6B/vZY67jcrThmTH45NH6oSbG9W+H/5vYFz88cT0GdGqOZ2/pg5ioCLx6+wC0ahyDfu19H+1U6bjKk6tV/w7NsOm5G/D7G3pKryPkLZyCJ27shQ7NY23batcsFn+e0NeuxZk8eb9061XIWzgFrRpb8kDH5rGIioxwah5s7Ww3oFMzzJ7Yz3IcpJPc0K6ur24jIggNoiLt3tuSB4diqkMDBy2FTUKPjY7E1lk3KpZOAqFBVKRdKcnqxSkJSHlpPBr50VolGNRJHSWMqNO3Xv1o6ZFrurmskrvGoXXQq7cPwOe/GeF0tdWiUTSulkrfb9xlqfKwnsCHx7fC8ukj3N6sTHlpPO4c0glP3tgLK6QmphOvao+PHk5E9zaN8dsxzvdWAGDjs/Wl+55x9VWHPeIaO1UfyMcFurm/55uaMVER+MPYXujr5upjeHx9/XHH5rEYopD0blD4HTqW0O8f0VWx7wUA9IhroqrQZfX+/UNszXRHdm+F9++3LyXnLZyCX56vHzLklzk3IemFmwEA3VpZqmGsv9G8hVPw8aPD8cZdV2P5o+pPjoFg7iwSAiIjSJeEFGixUknkMVlzTjVenNIfzRtGY9Y3qejTrgnWPT0GPf6yRocI6711zyD8+T8pbtf5q5umcLMn9sMdQzrj5VVH8NodA2zt/j96OBFfJ53EoC4t8P6mLMRERmDVk9fjTNkVtG8ei18N7wIATsnEleYNo7Ho3voSqlLfhm9mXoNDp8ow3cWQDvJqwE3PjbX9PWtCX1RcqZFVHwm8f/8QfLM/H7de3QGDX/lJVYxKPntsJPrPWwcA2DH7Jry98RgOnCi1Pf/8pH549Lp49H1xnd3r5PnZl34c7lhPpADwtay1iSsdZSfa1+++GrcP6eTU1ND6eQKWaqyNR886bSdeumpr38x1E2gtcUJnmoiJivDpR/jrUd1QJwS2HCvCvKkJdqWulo2iUXKpWtV21CRpq7uHdUavtk1w+5KdTs91atEQ/33iWrevj4qMsDRDm2mfGHrENcHzUjWItboGsFTjebLhmTEovlilJnw7ifGtkBjvfbO7J6S65aKKStuyhjGReGhUN7v1rKXmDs1jcbrsiqptN4yJxPW92mBH1jlERBD+dHMf3DeiK15bnYZLVbX4nVRV4viZLXs4EXd+8IvX78WO9PVpoOFNxyYNojw2yVz8wBAUlldizJub7Zb/dnQPXNWxOa5XaAKtB07oLGCaNohCRaWlm/YDI7vitdsH2Kpo5E3HRnZvhT25xVjz9GiMX7TN9hoASOjQDGmny522ffewzhjduw1GLvjZ5f5/emaM7bWDu7RA+2axOFN+BVtnjcUNb25B45hI7Jxzkybv1Vt9NG6p0yAqApU1dR7Xs1YTWeuRrfb8ZRx+OHgKPaUmvt8/cR2OnVU/HtI/H07E+YuVtn10atEQHzxo3xrk7mGd0bttE0yTTqyDO/ve2soqrkkDzJrQF1NkJ9RAiI2OVLw5HxFBAUvmACd0ppPRvdtge+Y5tGkSg3MXLCXP/5vYF3N/OIJts2502zJlxYxRqKqtQ4OoSHzz+2sx94fDuKFPHMb0jsPAzs0RP2e14uus/Qasz389YxTuXbYb86YmYGzfOPSIa2KrHgGAHbNvRE2dsCW1Pzi0iDCzLbPGIr/kssf1WjWOwd/uHOjUQa1ds1jMGNPT7nE7L6oNGsZEonOM59ZH1iaz8s54/kwjSUS2qw8j7H1hnF7jbqlCwqC9JyYmiuTkZEP2zfR3uaoWB0+WomXjaEx8ZzseHNkV828fgIrKGr87VVyorEHZ5WpERRCuVNeiW2v7tsPWhK51PSzTx5myK2jeMBoNYyKxK/s8+rVvipaN3bf9DmdEtE8Ikaj4HCd0prd9x0swoFMzWxMuvaUVlGNP7nmXNwoZMzN3CV3VnQMimkhEGUSURURzFJ4nInpPej6ViLTvS8tMa1i3lgFL5oBl3AxO5iwceUzoRBQJYAmASQASANxPRAkOq00C0Fv6NwPAhxrHyRhjzAM1JfQRALKEEDlCiCoAKwBMc1hnGoDPhMVuAC2IKLC3mRljLMypSeidAJyUPc6Xlnm7DmOMMR2pSehK/Wsd76SqWQdENIOIkokouaioSE18jDHGVFKT0PMBdJE97gygwId1IIRYJoRIFEIkxsUZNzEDY4yFIjUJPQlAbyLqTkQxAO4DsMphnVUAHpZau4wCUCaE4NmDGWMsgDz2FBVC1BDRkwDWA4gE8IkQ4ggRzZSeXwpgDYDJALIAXAIwXb+QGWOMKVHV9V8IsQaWpC1ftlT2twDwhLahMcYY84ZhPUWJqAjAcR9f3gbAOQ3D0ZNZYuU4tWeWWDlObekdZzchhOJNSMMSuj+IKNlV19dgY5ZYOU7tmSVWjlNbRsYZNjMWMcZYqOOEzhhjIcKsCX2Z0QF4wSyxcpzaM0usHKe2DIvTlHXojDHGnJm1hM4YY8wBJ3TGGAsRpkvonibbCMD+uxDRZiI6SkRHiOhpafnLRHSKiA5K/ybLXvO8FG8GEU2QLR9GRIek594j64zJ2sWaJ23/IBElS8taEdFPRJQp/d9Stn7A4ySivrJjdpCIyonoT8FyPInoEyIqJKLDsmWaHUMiakBEX0vL9xBRvIZxvklE6dKkM/8lohbS8ngiuiw7tktlrzEiTs0+a63idBPr17I484jooLTcsGNqRwhhmn+wDD2QDaAHgBgAKQASAhxDBwBDpb+bAjgGy8QfLwP4s8L6CVKcDQB0l+KPlJ7bC+AaWEarXAtgksax5gFo47DsDQBzpL/nAHjd6DgdPt8zALoFy/EEMAbAUACH9TiGAP4AYKn0930AvtYwzvEAoqS/X5fFGS9fz2E7RsSp2WetVZyuYnV4/u8A5hl9TOX/zFZCVzPZhq6EEKeFEPulvysAHIX7sd+nAVghhKgUQuTCMt7NCLJMANJMCLFLWD7RzwDcrnP41nj+Jf39L9k+gyHOcQCyhRDuehAHNE4hxDYAxQoxaHUM5dv6BsA4X64slOIUQmwQQtRID3fDMgqqS0bF6YZhx9NTrNI2fwXgK3fbCFSsVmZL6EE1kYZ0iTQEwB5p0ZPS5e0nsstwVzF3kv52XK4lAWADEe0johnSsnZCGglT+r9tEMRpdR/sfyDBdjyttDyGttdIybcMQGsdYn4MltKhVXciOkBEW4lotCwWo+LU6rMO1PEcDeCsECJTtszwY2q2hK5qIo1AIKImAL4F8CchRDks86j2BDAYwGlYLscA1zEH4r1cJ4QYCsucr08Q0Rg36xoZJ8gyNPNtAP4jLQrG4+mJL7HpHjcRvQCgBsAX0qLTALoKIYYAeBbAl0TUzMA4tfysA/U9uB/2hY+gOKZmS+iqJtLQGxFFw5LMvxBCfAcAQoizQohaIUQdgH/CUj0EuI45H/aXwJq/FyFEgfR/IYD/SjGdlS4DrZeDhUbHKZkEYL8Q4qwUc9AdTxktj6HtNUQUBaA51FdJeEREjwCYCuBB6ZIfUhXGeenvfbDUTfcxKk6NP2tdj6dsu3cC+Fr2HoLimJotoauZbENXUh3XxwCOCiEWyZbLJ8W+A4D1zvgqAPdJd7S7A+gNYK90qV5BRKOkbT4M4AcN42xMRE2tf8Nyg+ywFM8j0mqPyPZpSJwydiWeYDueDrQ8hvJt3Q1gkzXx+ouIJgKYDeA2IcQl2fI4IoqU/u4hxZljYJxafta6xSlzM4B0IYStKiVojqm/d1UD/Q+WiTSOwXIGfMGA/V8Py2VRKoCD0r/JAD4HcEhavgpAB9lrXpDizYCs5QWARFi+vNkAFkPquatRnD1gaSGQAuCI9VjBUkf3M4BM6f9WRsYpbb8RgPMAmsuWBcXxhOUkcxpANSwlqt9oeQwBxMJSzZQFS2uIHhrGmQVLHa31e2ptUXGX9J1IAbAfwK0Gx6nZZ61VnK5ilZZ/CmCmw7qGHVP5P+76zxhjIcJsVS6MMcZc4ITOGGMhghM6Y4yFCE7ojDEWIjihM8ZYiOCEzhhjIYITOmOMhYj/ByYvvVcWFV4OAAAAAElFTkSuQmCC\n",
|
||
"text/plain": [
|
||
"<Figure size 432x288 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {
|
||
"needs_background": "light"
|
||
},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.plot(loss_value)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"file_path=\"./net_trainned6000\"\n",
|
||
"save_net(file_path,net)"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.8.5"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 4
|
||
}
|