629 lines
40 KiB
Text
629 lines
40 KiB
Text
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"#Tous les codes sont basés sur l'environnement suivant\n",
|
|||
|
"#python 3.7\n",
|
|||
|
"#opencv 3.1.0\n",
|
|||
|
"#pytorch 1.4.0\n",
|
|||
|
"\n",
|
|||
|
"import torch\n",
|
|||
|
"from torch.autograd import Variable\n",
|
|||
|
"import torch.nn as nn\n",
|
|||
|
"import torch.nn.functional as F\n",
|
|||
|
"import cv2\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import random\n",
|
|||
|
"import math\n",
|
|||
|
"import pickle\n",
|
|||
|
"import random\n",
|
|||
|
"from PIL import Image\n",
|
|||
|
"import sys\n",
|
|||
|
"from IPython.display import clear_output\n",
|
|||
|
"from datetime import datetime"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
# The helpers in this cell are utilities; the network itself does not use them.


def tensor_imshow(im_tensor, cannel):
    """Show a 4-D image tensor with matplotlib.

    Parameters
    ----------
    im_tensor : torch.Tensor
        Tensor whose shape unpacks into four values (b, c, h, w).
    cannel : int
        Index applied to the first axis of the squeezed array when the
        tensor has more than one channel; ignored when c == 1.
    """
    _, c, _, _ = im_tensor.shape
    # detach() drops the autograd history; cpu() makes numpy() valid for
    # tensors that live on the GPU (numpy() raises on CUDA tensors).
    image = im_tensor.detach().cpu().squeeze().numpy()
    if c == 1:
        plt.imshow(image)
    else:
        plt.imshow(image[cannel, :])
|
|||
|
"\n",
|
|||
def get_training_fragment(frag_size, im):
    """Cut a random, grid-aligned square patch out of an image.

    Usage: ``frag, vt = get_training_fragment(frag_size, image)``

    Parameters
    ----------
    frag_size : int
        Side length of the square patch.
    im : array indexed as (height, width, channel)
        Image the patch is taken from.

    Returns
    -------
    frag
        ``im`` sliced to the selected (frag_size x frag_size) cell.
    vt : torch.Tensor
        float32 ground truth of shape (1, 1, vt_h, vt_w): zeros with a
        single 1 (a Dirac) at the position derived from the patch cell.
    """
    height, width, _ = im.shape

    # Row is drawn before column so a seeded `random` reproduces the same patches.
    row = random.randint(0, int(height / frag_size) - 1)
    col = random.randint(0, int(width / frag_size) - 1)

    # The ground-truth grid has one cell per quarter of a fragment side.
    cell = frag_size / 4
    vt_h = math.ceil((height + 1) / cell)
    vt_w = math.ceil((width + 1) / cell)

    # Relative start and end of the patch along each axis; their mean is its centre.
    top = row * frag_size / (height - 1)
    bottom = (row + 1) * frag_size / (height - 1)
    left = col * frag_size / (width - 1)
    right = (col + 1) * frag_size / (width - 1)
    peak_row = round((vt_h - 1) * (top + bottom) / 2)
    peak_col = round((vt_w - 1) * (left + right) / 2)

    truth = np.zeros((vt_h, vt_w), dtype=np.float32)
    truth[peak_row, peak_col] = 1
    vt = torch.from_numpy(truth.reshape(1, 1, vt_h, vt_w))

    rows = slice(row * frag_size, (row + 1) * frag_size)
    cols = slice(col * frag_size, (col + 1) * frag_size)
    return im[rows, cols, :], vt
|
|||
|
"\n",
|
|||
def img2tensor(im, add_batch_dim=False):
    """Convert an image into a float32 tensor.

    The image is indexed (Height, Width, Channel); the tensor is laid out
    (Channel, Height, Width), or (1, Channel, Height, Width) when
    ``add_batch_dim`` is true.

    The original code built the batched tensor but returned the unbatched
    one. The unbatched return stays the default so existing callers are
    unaffected; pass ``add_batch_dim=True`` to get the 4-D layout the
    network consumes.

    Parameters
    ----------
    im : array-like indexed as (height, width, channel)
    add_batch_dim : bool, optional
        Prepend a batch axis of size 1. Default False.

    Returns
    -------
    torch.Tensor
        float32 tensor with the channel axis moved to the front.
    """
    im = np.array(im, dtype="float32")
    tensor_cv = torch.from_numpy(np.transpose(im, (2, 0, 1)))
    if add_batch_dim:
        return tensor_cv.unsqueeze(0)
    return tensor_cv
|
|||
|
"\n",
|
|||
def show_coordonnee(position_pred):
    """Locate the maximum of a correlation map, in relative coordinates.

    Usage: ``x, y = show_coordonnee(correlation_map)``

    Parameters
    ----------
    position_pred : torch.Tensor
        Correlation map that squeezes down to two dimensions.

    Returns
    -------
    tuple of float
        (row / height, column / width) of the maximum. When several cells
        share the maximum, the mean of their coordinates is used.
    """
    # cpu() is needed before numpy() when the prediction lives on the GPU.
    map_corre = position_pred.squeeze().detach().cpu().numpy()
    h, w = map_corre.shape
    rows, cols = np.where(map_corre == map_corre.max())
    return rows.mean() / h, cols.mean() / w
|
|||
|
"\n",
|
|||
|
"# Filtrer les patchs en fonction du nombre de pixels noirs dans le patch\n",
|
|||
|
"# Si seuls les pixels non noirs sont plus grands qu'une certaine proportion(seuillage), revenez à True, sinon False\n",
|
|||
|
"def test_fragment32_32(frag,seuillage):\n",
|
|||
|
" a=frag[:,:,0]+frag[:,:,1]+frag[:,:,2]\n",
|
|||
|
" mask = (a == 0)\n",
|
|||
|
" arr_new = a[mask]\n",
|
|||
|
" if arr_new.size/a.size<=(1-seuillage):\n",
|
|||
|
" return True\n",
|
|||
|
" else:\n",
|
|||
|
" return False\n",
|
|||
|
" \n",
|
|||
# save_net / load_net store a network in a file and read it back.
def save_net(file_path, net):
    """Pickle ``net`` into the file at ``file_path`` (overwriting it).

    Parameters
    ----------
    file_path : str
        Destination file, opened in binary write mode.
    net : object
        Object handed to ``pickle.dump``.
    """
    # The context manager closes the file even when pickling raises.
    with open(file_path, 'wb') as pkl_file:
        pickle.dump(net, pkl_file)
|
|||
def load_net(file_path):
    """Read back an object pickled in the file at ``file_path``.

    Parameters
    ----------
    file_path : str
        File opened in binary read mode.

    Returns
    -------
    object
        Whatever ``pickle.load`` reconstructs from the file.
    """
    # SECURITY: unpickling executes code embedded in the file; only load
    # files that come from a trusted source.
    with open(file_path, 'rb') as pkl_file:
        return pickle.load(pkl_file)
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
def ini():
    """Build DeepMatch-style weights usable as an optional start value for Conv1.

    Returns
    -------
    torch.nn.Parameter
        Trainable float32 parameter of shape (8, 3, 3, 3). Output channel k
        holds the k-th 3x3 filter below, repeated on each of the 3 input
        channels.
    """
    filtres = np.array(
        [
            [[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
            [[2, 1, 0], [1, 0, -1], [0, -1, -2]],
            [[1, 0, -1], [2, 0, -2], [1, 0, -1]],
            [[0, -1, -2], [1, 0, -1], [2, 1, 0]],
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
            [[-2, -1, 0], [-1, 0, 1], [0, 1, 2]],
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]],
        ],
        dtype='float32',
    )
    # (8, 3, 3) -> (8, 1, 3, 3) -> one copy of each filter per input channel
    kernel = torch.from_numpy(filtres).unsqueeze(1).repeat(1, 3, 1, 1)
    return torch.nn.Parameter(kernel, requires_grad=True)
|
|||
|
"\n",
|
|||
def kernel_add_ini(n, m):
    """Build the initial weight of an "add" convolution layer.

    n and m mean the input patch holds n * m sub-patches. For a 16 x 16
    input patch, the 4 x 4 patches of the first layer give n = 4, m = 4 and
    the 8 x 8 patches of the second layer give n = 2, m = 2.

    Output channel (i, j) sums the four input channels of the 2 x 2 group
    of sub-patches whose top-left member sits at row 2*i, column 2*j.

    Parameters
    ----------
    n, m : int
        Number of rows and columns of sub-patches.

    Returns
    -------
    torch.nn.Parameter
        Frozen float32 weight of shape (n//2 * m//2, n*m, 1, 1).

    Raises
    ------
    ValueError
        If n < 2 or m < 2, i.e. there is no 2 x 2 group to merge. The
        original code failed with UnboundLocalError in that case.
    """
    input_canal = int(n * m)
    rows = int(n / 2)
    cols = int(m / 2)
    if rows < 1 or cols < 1:
        raise ValueError(
            "kernel_add_ini needs n >= 2 and m >= 2, got n={}, m={}".format(n, m)
        )

    # Allocate once instead of growing the tensor with torch.cat in the loop.
    add = torch.zeros([rows * cols, input_canal, 1, 1], dtype=torch.float32)
    for i in range(rows):
        for j in range(cols):
            top_left = i * 2 * m + j * 2
            # Offsets of the 2 x 2 group in the row-major channel order.
            for offset in (0, 1, m, m + 1):
                add[i * cols + j, top_left + offset, 0, 0] = 1
    return torch.nn.Parameter(add, requires_grad=False)
|
|||
|
"\n",
|
|||
def kernel_shift_ini(n, m):
    """Build the initial weight of a "shift" convolution layer.

    A shift layer followed by an add layer performs the aggregation step;
    see kernel_add_ini for the meaning of n and m.

    Each channel is connected only to itself, through a 3 x 3 kernel with a
    single 1 in a corner. The corner depends on the parity of the
    sub-patch position (i, j): even row -> top, odd row -> bottom, even
    column -> left, odd column -> right.

    Parameters
    ----------
    n, m : int
        Number of rows and columns of sub-patches.

    Returns
    -------
    torch.nn.Parameter
        Frozen weight of shape (n*m, n*m, 3, 3).
    """
    canaux = int(n * m)
    kernel_shift = torch.zeros([canaux, canaux, 3, 3])

    for i in range(n):
        for j in range(m):
            canal = i * m + j
            # Parity 0 -> index 0 (top / left), parity 1 -> index 2 (bottom / right).
            kernel_shift[canal, canal, 2 * (i % 2), 2 * (j % 2)] = 1

    return torch.nn.Parameter(kernel_shift, requires_grad=False)
|
|||
|
"\n",
|
|||
def get_patch(fragment, psize, n, m):
    """Return the small patch at row n, column m of the input patch.

    The result is used as a convolution kernel to compute a correlation map.

    Parameters
    ----------
    fragment : 4-D tensor
        Only the last two axes are cut; the first two are kept whole.
    psize : int
        Side of the small patch.
    n, m : int
        Row and column of the small patch in the grid of psize-sized cells.
    """
    row_start = n * psize
    col_start = m * psize
    return fragment[:, :, row_start:row_start + psize, col_start:col_start + psize]
|
|||
|
"\n",
|
|||
|
"###################################################################################################################\n",
|
|||
class Net(nn.Module):
    """Correlation network that locates a fragment inside a larger image.

    conv1 turns both inputs into channel-normalised descriptors. The
    fragment descriptor is cut into psize x psize patches, each patch is
    correlated with the image descriptor, and the per-patch maps are merged
    by up to three max-pool -> shift -> add stages.

    Parameters
    ----------
    frag_size : int
        Side of the square fragments the aggregation layers are sized for.
    psize : int
        Side of the elementary patches.

    NOTE(review): the shift/add convolutions keep nn.Conv2d's default bias.
    Only their weights are replaced by frozen parameters, so those biases
    stay randomly initialised and trainable - confirm this is intended.
    """

    def __init__(self, frag_size, psize):
        super(Net, self).__init__()

        # Remembered so forward() cuts patches of the size the layers were built for.
        self.psize = psize

        # The input fragment holds n * m elementary patches.
        n = int(frag_size / psize)
        m = int(frag_size / psize)

        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
        # To start Conv1 from the DeepMatch-style weights, enable the next line.
        # self.conv1.weight=ini()
        self.Relu = nn.ReLU(inplace=True)
        self.maxpooling = nn.MaxPool2d(3, stride=2, padding=1)

        self._add_aggregation_stage(1, n, m)

        # Each further stage merges 2 x 2 groups again. Once n or m drops
        # below 2 there is nothing left to aggregate, so no layer is built.
        for stage in (2, 3):
            n = int(n / 2)
            m = int(m / 2)
            if n >= 2 and m >= 2:
                self._add_aggregation_stage(stage, n, m)

    def _add_aggregation_stage(self, index, n, m):
        """Register the shift<index> / add<index> layers for an n x m grid of maps."""
        shift = nn.Conv2d(n * m, n * m, kernel_size=3, stride=1, padding=1)
        shift.weight = kernel_shift_ini(n, m)
        add = nn.Conv2d(n * m, int(n / 2) * int(m / 2), kernel_size=1, stride=1, padding=0)
        add.weight = kernel_add_ini(n, m)
        # Attribute names are unchanged (shift1, add1, ...) so saved networks
        # and state dicts keep matching.
        setattr(self, "shift{}".format(index), shift)
        setattr(self, "add{}".format(index), add)

    def get_descripteur(self, img, using_cuda):
        """Compute the normalised descriptor of a batch of images.

        Parameters
        ----------
        img : torch.Tensor
            Batch fed to conv1 (3 input channels).
        using_cuda : bool
            Kept for backward compatibility; the constant layer is now
            created directly on the descriptor's device, so the flag no
            longer has to agree with where the data lives.

        Returns
        -------
        torch.Tensor
            conv1 + ReLU output with one extra constant channel, L2
            normalised along the channel axis (F.normalize default dim=1).
        """
        descripteur_img = self.Relu(self.conv1(img))
        b, c, h, w = descripteur_img.shape
        # Constant layer that prevents a division by zero during normalisation.
        couche_constante = torch.full(
            (b, 1, h, w), 0.5,
            dtype=descripteur_img.dtype, device=descripteur_img.device,
        )
        descripteur_img = torch.cat((descripteur_img, couche_constante), 1)
        return F.normalize(descripteur_img)

    def forward(self, img, frag, using_cuda):
        """Return the correlation map of each fragment over its image.

        Parameters
        ----------
        img : torch.Tensor
            Batch of images, one per fragment.
        frag : torch.Tensor
            Batch of fragments, shape (b, c, h, w).
        using_cuda : bool
            Forwarded to get_descripteur (kept for backward compatibility).

        Returns
        -------
        torch.Tensor
            Aggregated correlation maps, each divided by its own maximum.
        """
        # Networks pickled before `psize` was stored fall back to the value
        # that used to be hard-coded here.
        psize = getattr(self, "psize", 4)

        descripteur_frag = self.get_descripteur(frag, using_cuda)
        descripteur_img = self.get_descripteur(img, using_cuda)

        b, c, h, w = frag.shape
        n = int(h / psize)
        m = int(w / psize)

        db, dc, dh, dw = descripteur_img.shape

        # Correlation maps of the n * m small patches, obtained by convolution.
        # The batch is folded into the channel axis and `groups=db` makes
        # sample k of the image batch meet only patch k of the fragment batch.
        images_groupees = descripteur_img.view(1, db * dc, dh, dw)
        cartes = []
        for i in range(n):
            for j in range(m):
                noyau = get_patch(descripteur_frag, psize, i, j)
                # psize // 2 equals the previously hard-coded padding of 2 when psize is 4.
                carte = F.conv2d(images_groupees, noyau, padding=psize // 2, groups=db)
                cartes.append(carte.view(db, 1, carte.size(2), carte.size(3)))
        # Concatenate once rather than growing the tensor inside the loop.
        map_corre = torch.cat(cartes, 1)

        # Aggregation step.
        map_corre = self.maxpooling(map_corre)
        map_corre = self.shift1(map_corre)
        map_corre = self.add1(map_corre)

        # Repeat the aggregation until the map of the whole fragment is
        # reached; the conditions mirror the ones used in __init__.
        for stage in (2, 3):
            n = int(n / 2)
            m = int(m / 2)
            if n >= 2 and m >= 2:
                map_corre = self.maxpooling(map_corre)
                map_corre = getattr(self, "shift{}".format(stage))(map_corre)
                map_corre = getattr(self, "add{}".format(stage))(map_corre)

        # Normalise each sample by its own maximum. Dividing by the maximum
        # of the whole batch (as before) made a sample's output depend on
        # the other samples drawn with it; for a batch of one the result is
        # unchanged.
        nb = map_corre.size(0)
        maxima = map_corre.reshape(nb, -1).max(dim=1)[0].view(nb, 1, 1, 1)
        map_corre = map_corre / maxima
        # SoftMax normalisation (alternative, disabled):
        # map_corre=(F.softmax(map_corre.reshape(1,1,h*w,1),dim=2)).reshape(b,c,h,w)
        return map_corre
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Dataset and Dataloader"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"import re\n",
|
|||
|
"from PIL import Image\n",
|
|||
|
"from torchvision import transforms\n",
|
|||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
|||
|
"from time import time\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
class FragmentDataset(Dataset):
    """Dataset of (fragment, fresque, ground truth) triplets read from disk.

    Layout read by the constructor, relative to ``fragments_path``::

        <name>.ppm                        full fresque image
        fragments/<name>/<fragment file>  fragment images
        fragments/<name>/vt/<vt file>     one text file per fragment

    Only entries whose path contains "fresque" are scanned, and only files
    whose name contains "_dev_" (train) or "_bench_" (test) are kept.
    """

    def __init__(
        self,
        fragments_path,
        train,
        frags_transform=transforms.ToTensor(),
        fresques_transform=None,
        vts_transform=None,
    ):
        """
        Parameters
        ----------
        fragments_path: str
            Path to the root folder holding the fresques and the
            "fragments" folder.
        train: boolean
            True for the train set (_dev_), False for the test set (_bench_).
        frags_transform: torchvision.transform
            Transform applied to every fragment image. Default: ToTensor()
        fresques_transform: torchvision.transform
            Transform applied to every fresque image. frags_transform if None.
        vts_transform: callable or None
            Transform applied to every ground-truth tensor. Default: None
            (the tensor is returned as built).

        Raises
        ------
        ValueError
            If a fresque folder does not hold the same number of fragment
            files and vt files for the selected set.
        """
        self.base_path = fragments_path
        self.frags_transform = frags_transform
        self.fresques_transform = fresques_transform if fresques_transform else frags_transform
        self.fragments_list = []
        self.vts_transform = vts_transform

        # Separates train (dev) fragments from test (bench) fragments.
        self.match_expr = "_dev_" if train else "_bench_"

        fragments_path = os.path.join(self.base_path, "fragments")
        # Sorted so the item order does not depend on the file system.
        for fresque_dir in sorted(os.listdir(fragments_path)):
            current_path = os.path.join(fragments_path, fresque_dir)

            # Skip extra files lying next to the fresque folders.
            # NOTE(review): the substring test is kept on the full path, as
            # before, so a root path containing "fresque" matches every
            # entry - confirm whether only the folder name should be tested.
            if "fresque" not in current_path or not os.path.isdir(current_path):
                continue

            # Full fresque image (ie: ..path/fresque0.ppm). The folder name
            # is used directly instead of splitting the path on "/", which
            # breaks with Windows separators.
            full_fresque_path = os.path.join(self.base_path, fresque_dir + ".ppm")

            # Fragments (ie: ..path/fresque0/frag_bench_000.ppm) and their
            # ground truths (ie: ..path/fresque0/vt/...) for the chosen set.
            # Each list is filtered before pairing so that a stray file (or
            # the "vt" folder itself) cannot shift the pairing.
            vts_path = os.path.join(current_path, "vt")
            frag_names = sorted(
                name for name in os.listdir(current_path) if self.match_expr in name
            )
            vt_names = sorted(
                name for name in os.listdir(vts_path) if self.match_expr in name
            )
            if len(frag_names) != len(vt_names):
                raise ValueError(
                    "{}: {} fragment files but {} vt files match '{}'".format(
                        current_path, len(frag_names), len(vt_names), self.match_expr
                    )
                )

            # Group every fragment with the full fresque path and its vt (tuple).
            self.fragments_list.extend(
                (
                    os.path.join(current_path, frag_name),
                    full_fresque_path,
                    os.path.join(vts_path, vt_name),
                )
                for frag_name, vt_name in zip(frag_names, vt_names)
            )

    def __len__(self):
        """Number of (fragment, fresque, vt) triplets found."""
        return len(self.fragments_list)

    def __getitem__(self, idx):
        """Load triplet ``idx``.

        Returns
        -------
        tuple
            (transformed fragment, transformed fresque, vt) where vt is a
            float32 tensor of shape (1, H, W) holding a single 1.
        """
        frag_path, fresque_path, vt_path = self.fragments_list[idx]
        fragment = Image.open(frag_path)
        fresque = Image.open(fresque_path)

        # One integer per line; blank lines (e.g. a trailing newline) are ignored.
        with open(vt_path, 'r') as f:
            data_vt = [int(line) for line in f if line.strip()]

        # Values 0 and 1 size the map, values 2 and 3 locate the 1; all are
        # divided by 4.
        # NOTE(review): the size comes from the vt file, not from the
        # transformed fresque - confirm they agree when fresques_transform resizes.
        vt_h = int(data_vt[0] / 4) + 1
        vt_w = int(data_vt[1] / 4) + 1
        vt = np.zeros((vt_h, vt_w), dtype=np.float32)
        vt[int(data_vt[2] / 4), int(data_vt[3] / 4)] = 1
        vt = torch.from_numpy(vt.reshape(1, vt_h, vt_w))

        # The transform used to be stored but never applied.
        if self.vts_transform is not None:
            vt = self.vts_transform(vt)

        return self.frags_transform(fragment), self.fresques_transform(fresque), vt
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Usage"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
# Fresques are resized to 1000 x 1000 and then converted to tensors.
fresques_tnsf = transforms.Compose(
    [transforms.Resize((1000, 1000)), transforms.ToTensor()]
)

# Train (_dev_) split first, then test (_bench_) split, from the same root folder.
train, test = (
    FragmentDataset(
        fragments_path="training_data_random_shift_color",
        train=is_train,
        fresques_transform=fresques_tnsf,
    )
    for is_train in (True, False)
)

# Mini-batch size; also reused in the names of the saved files.
bs = 4

# Both loaders share the same settings.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=bs, num_workers=4, pin_memory=False, shuffle=True)
    for dataset in (train, test)
)
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Train"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"[EPOCH 2] Batch 4499/4500\n",
|
|||
|
"Temps par batch: 2.74\n",
|
|||
|
"Done with epoch 2\n",
|
|||
|
"Net sauvegardés dans ./trained_net/net_trainned_MB4_02-10_20-49_0003\n",
|
|||
|
"Poids sauvegardés dans ./trained_net/save_weights_MB4_02-10_20-49_0003\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
frag_size = 16
psize = 4

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
using_cuda = device.type == "cuda"

net = Net(frag_size, psize).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
loss_func = torch.nn.MSELoss()

num_epochs = 3

loss_value = []   # one loss per batch
para_value = []
w_values = []     # snapshot of the conv1 weights every 10 batches

time_old = time()

for epoch in range(num_epochs):
    for i, (fragments, fresques, vts) in enumerate(train_loader):

        clear_output(wait=True)
        print("[EPOCH {}] Batch {}/{}\nTemps par batch: {:.3}".format(epoch, i, len(train_loader), time() - time_old))
        time_old = time()

        fragments = fragments.to(device)
        fresques = fresques.to(device)

        optimizer.zero_grad()
        preds = net(fresques, fragments, using_cuda)

        # Release the inputs as soon as the forward pass is done.
        del fragments, fresques

        vts = vts.to(device)
        # Conventional (prediction, target) order; MSE is symmetric so the
        # value is unchanged.
        cost = loss_func(preds, vts)
        cost.backward()
        del vts
        optimizer.step()

        if i % 10 == 0:
            # copy(): on the CPU, numpy() shares memory with the live
            # weights, so without it every snapshot would alias the latest values.
            w_values.append(net.conv1.weight.detach().cpu().numpy().copy())

        loss_value.append(cost.item())

    # The original `torch.cuda.empty_cache` lacked its parentheses and never
    # ran. It is called once per epoch here to keep the training loop cheap.
    if using_cuda:
        torch.cuda.empty_cache()
    print('Done with epoch ', epoch)

# Save the network.
save_dir = './trained_net/'
os.makedirs(save_dir, exist_ok=True)
expe_id = 3
# A single timestamp, so both files always share the same suffix.
timestamp = datetime.now().strftime("%m-%d_%H-%M")
net_filename = save_dir + "net_trainned_MB{}_{}_{:04}".format(bs, timestamp, expe_id)
save_net(net_filename, net)

# Save the recorded weights.
poids_filename = save_dir + "save_weights_MB{}_{}_{:04}".format(bs, timestamp, expe_id)
with open(poids_filename, 'wb') as f:
    pickle.dump(w_values, f)

print("Net sauvegardés dans {}".format(net_filename))
print("Poids sauvegardés dans {}".format(poids_filename))
|
|||
|
" "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"13500"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"len(loss_value)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"[<matplotlib.lines.Line2D at 0x7f6df46b8370>]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD4CAYAAADlwTGnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3dd3wUdfoH8M+TRkloIfSWAAFE6ZFipUpT+Ymeh717nHq28xREucMGeqdnl0PU87ArKCpNpEhHAkgJNQkt1EQgtIS07++Pnd3M7s6WJJudnd3P+/Xyxe7Md3eeRJhnvl2UUiAiosgTZXYARERkDiYAIqIIxQRARBShmACIiCIUEwARUYSKMevCSUlJKjk52azLExFZ0vr16/OUUo0C8V2mJYDk5GSkp6ebdXkiIksSkX2B+i42ARERRSgmACKiCMUEQEQUoZgAiIgiFBMAEVGEYgIgIopQTABERBHKcglAKYWZ63NQUFRqdihERJZmuQSwJvs4/vr1Jjz34zazQyEisjTLJYCNB04AADIO5ZscCRGRtVkuAezLOwcAOHqq0ORIiIiszXIJIDZGAADFpdzKkoioKqyXAKJtIReXlJkcCRGRtVkuAcRpCeD0+RKTIyEisjbLJYCYaHG8VorNQERElWW5BBAdVR7yp2v3mxgJEZG1WS8BSHkNYNvhUyZGQkRkbdZLALqI2QRERFR5lksAUVH6PgATAyEisjjLJQB9E1AZMwARUaVZLwGwBkBEFBDWTgAmxkFEZHWWTgBsAiIiqjzLJQDR9QGc5WxgIqJKs1wCiNHVAHJOFHAoKBFRJVkuAfRoXd/xOuPQKUycnWFiNERE1mW5BJCSFO/0fsaafcg/V2xSNERE1mW5BCAQt2N5Z8+bEAkRkbVZLwG43/+JiKgSLJcAYqPdQ357caYJkRARWZtfCUBEhonIThHJFJFxBufricgPIrJJRDJE5K7Ah+rZtxsPBvNyRERhwWcCEJFoAO8AGA6gM4CbRKSzS7EHAWxTSnUD0B/AqyISF+BYiYgogPypAfQGkKmUylZKFQH4AsAolzIKQB2xzdJKAHAcAGdpERGFMH8SQAsAB3Tvc7Rjem8DuADAIQBbADyilHLbtV1E7heRdBFJz83NrWTIREQUCP4kAKNxN67Tb4cC+A1AcwDdAbwtInXdPqTUNKVUmlIqrVGjRhUOloiIAsefBJADoJXufUvYnvT17gIwS9lkAtgDoFNgQiQiourgTwJYByBVRFK0jt0xAL53KbMfwCAAEJEmADoCyA5koEREFFg+E4BSqgTAQwAWANgO4CulVIaIjBWRsVqx5wFcIiJbACwC8JRSKq+6gjYyf+uRYF6OiMjyYvwppJSaC2Cuy7GputeHAFwV2NAqZuwn65H10ggAznsGEBGRMcvNBPam3dNz0euFhWaHQURkCWGVAADgJFcGJSLyS9glACIi8g8TABFRhGICICKKUEwAREQRigmAiChCMQEQEUUoSyaAZ6923Y6AiIgqypIJ4J7LUryeX7E7D7mnzyO/gHMCiIg88WspCKu59YO1jtd7p4w0MRIiotBlyRoAERFVHRMAEVGEYgIgIopQTABERBGKCYCIKEIxARARRSgmACKiCGXZBNAqsZbZIRARWZplE0BKUoLZIRARWZplEwAREVUNEwARUYRiAiAiilBhnwBKy5TZIRARhaSwTwB/+3oTl4UmIjIQ9glg1saD6DbpJ7PDICIKOWGfAIiIyFjEJYAnvt6Eq/79i9lhEBGZLix3BDMyfXk2CotL8c36HLNDISIKCZZNAMMubIplu3L9Lv/CnO3VGA0RkfVYtgmof8dGZodARGRplk0AMVESkO9ZuO0onvthG5LHzcGWnPyAfCcRkRVYtgmocd2aAfme+/6X7nj9w+ZD6NKyXkC+l4go1Fm2BlBVby7ajYKiUqdjSnHWMBFFDsvWAKrqtYW7cPRUod
Mx3v+JKJJEbA0AAPLOnDc7BCIi00R0AoiJiugfn4giXETfAaNcRhKxBYiIIolfCUBEhonIThHJFJFxHsr0F5HfRCRDRCyx1kK0y0hS9gEQUSTx2QksItEA3gEwBEAOgHUi8r1SapuuTH0A7wIYppTaLyKNqyvgQHKtARARRRJ/agC9AWQqpbKVUkUAvgAwyqXMzQBmKaX2A4BS6lhgw6we0eKcAD5cucetzGdr92PWBq4fREThx58E0ALAAd37HO2YXgcADURkqYisF5Hbjb5IRO4XkXQRSc/N9X8dH0/+dGVb1KlR+ZGsX/uxMNzT327B419tqvQ1iIhClT8JwKidxLW1PAZALwAjAQwF8KyIdHD7kFLTlFJpSqm0Ro2qvpbP+OEXYMukoVX+HiKiSORPAsgB0Er3viWAQwZl5iulziql8gAsA9AtMCEG16GTBWaHQEQUFP4kgHUAUkUkRUTiAIwB8L1LmdkALheRGBGpDaAPAEuuv3zJlMU45jJDGACmLcviUhFEFFZ8JgClVAmAhwAsgO2m/pVSKkNExorIWK3MdgDzAWwG8CuA6UqprdUXtrM6NQO7osW6vSfcjr00dwf2Hz8X0OsQEZnJr3kASqm5SqkOSql2SqkXtWNTlVJTdWX+qZTqrJS6SCn1enUFbOSze/sG9Pu2Hz6F5HFzsDnnpNPxkjLWAIgofITFTGAV4Dm8by/JBAAsyDjifB3e/4kojIRFAqgu7yzJcjliywCFxaV4Zf4OFBaXun+IiMgimAAqwN4C9P6ybLy7NAsfrdzrOJd7+jxWZ/1uTmBERJUQFgkgWE0z9usUlZbZ/iwpc5y7/r1VuOn9NcEJhIgoAMIiAdSOiw7Kdex9DUYz4zhCiIisJiwSQGqTOkG5TlmZ8/tAdz4TEQVTWCQAAGjTsHa1X8NxwxeuIkpE1hc2CSAYt+QPV+x1es9hoURkZWGTAIJh5oYc5Jw4F5RkQ0RU3cImAUy9rVdQrqPvB8g4lI9P1+4LynWJiAItbBJAp6Z1cV0P120KAu/5OY6N0PDz9mOY8G3QljwiIgqowK6iZrJgNM0s3HYUDWrHBuFKRETVK2xqAACCkwEACHsBiCgMhFUCiAqB4ZmLdxzlpjJEZAlhlQDMv/0Dd/83HSPeXG52GEREPoVVAgiFGgAAnDxXbHYIREQ+hVUCsO8MNubiVj5KVs2+42er9fuJiIIhrBKAXbtGCdX6/Wuyjzu9H/XOymq9HhFRdQirBNCiQS0AQOO6NYJ63U0HTrod42YxRBTqwioB3NEvGR/emYZruzU3OxQ8+OkGs0MgIvIqrCaCRUUJBnZqYnYYAIBFO46ZHQIRkVdhVQPQe/G6i8wOgYgopIVtArilTxv8bWhHs8MgIgpZYZsAAKBhfJzZIUApBWWwcUD63uNIHjcHG/afMCEqIqIwTwChsF9L75cWoc9Li9yO/7IrFwCwYndesEMiIgIQZp3AoSj39Hmv57mrGBGZJaxrAE3r1jQ7BI9cF63YuP8EBr26FGfPl5gSDxFFnrBOAAM6NTY7BJ/sG81PnrcDWblnseVgvskREVGkCOsEAAA9W9c3OwRjIbJwHRFFrrBPAGLijXZVlu8OXnsfwPmSMpfjCh+v2ou8M977EIiIKivsE0C3lubVAG5+f63j9bHThR5H/GTlnnFbT2j3sTP4+/cZ+MtnG6s1RiKKXGGfAMaP6IQf/3KZ2WHg+vdW4dYP1uLzX/fjxv+sdjq34/Bpt/JFWo0gv4B7CxBR9Qj7YaCx0VG4qEU9s8PAgeO2bSLHz9oCAEhr08Bxjt0BRGSGsK8BhKp3l2YBAA6cOBcSW1kSUeSJmARgnxPQUtszIFTM2nCwQjWAsjLOHCOiwIiYBPDOLT0BANFRofi8XR6Tt5nBCzKOoO3Tc7HrqHufARFRRUVMAqgVGw0gNJdeGPvJesfr42eLUFJaZlhuQcYRAMAHy/cEJS4iCm8RkwDszSxloZgBdB78bAMm/b
DNa5kv0w9gCTecIaIq8isBiMgwEdkpIpkiMs5LuYtFpFREbghciIEV4vd/AMCMNft8ljmUXxCESIgonPlMACISDeA
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 432x288 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"plt.plot(loss_value)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"ename": "NameError",
|
|||
|
"evalue": "name 'net' is not defined",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m<ipython-input-5-47d0cd6ed22d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mfile_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"./net_trainned6000_MB_102\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msave_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|||
|
"\u001b[0;31mNameError\u001b[0m: name 'net' is not defined"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"file_path=\"./net_trainned6000_MB_102\"\n",
|
|||
|
"save_net(file_path,net)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 4
|
|||
|
}
|