{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib\n",
    "%matplotlib notebook\n",
    "from matplotlib import pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib import cm\n",
    "import pickle\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import cv2\n",
    "from random import randint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions. Some of them are not used by the network itself; they\n",
    "# are helper tools for preparing inputs and reading results.\n",
    "#\n",
    "# Shape conventions:\n",
    "#   image:  [Height, Width, Channel]\n",
    "#   tensor: [Batch, Channel, Height, Width]\n",
    "\n",
    "\n",
    "def img2tensor(im):\n",
    "    \"\"\"Convert an image into a float32 tensor with a leading batch axis.\n",
    "\n",
    "    All data processed by the network are tensors.\n",
    "\n",
    "    Args:\n",
    "        im: array-like image of shape [Height, Width, Channel].\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape [1, Channel, Height, Width].\n",
    "    \"\"\"\n",
    "    im = np.array(im, dtype=\"float32\")\n",
    "    return torch.from_numpy(np.transpose(im, (2, 0, 1))).unsqueeze(0)\n",
    "\n",
    "\n",
    "def show_coordonnee(position_pred):\n",
    "    \"\"\"Return the coordinates of the maximum value of a correlation map.\n",
    "\n",
    "    Usage: row, col = show_coordonnee(correlation_map)\n",
    "\n",
    "    Args:\n",
    "        position_pred: tensor (on any device) that squeezes to a 2-D map.\n",
    "\n",
    "    Returns:\n",
    "        (row, col) of the maximum; when several pixels tie, the mean of\n",
    "        their row indices and the mean of their column indices.\n",
    "    \"\"\"\n",
    "    map_corre = position_pred.squeeze().detach().cpu().numpy()\n",
    "    rows, cols = np.where(map_corre == map_corre.max())\n",
    "    return rows.mean(), cols.mean()\n",
    "\n",
    "\n",
    "def load_net(file_path):\n",
    "    \"\"\"Load and return a pickled network from file_path.\n",
    "\n",
    "    WARNING: pickle.load can execute arbitrary code; only load trusted files.\n",
    "    \"\"\"\n",
    "    # 'with' closes the file even if unpickling raises.\n",
    "    with open(file_path, 'rb') as pkl_file:\n",
    "        return pickle.load(pkl_file)\n",
    "\n",
    "\n",
    "def kernel_add_ini(n, m):\n",
    "    \"\"\"Build the fixed 1x1 kernel that sums every 2x2 block of patch channels.\n",
    "\n",
    "    The input has n*m channels, one per patch of an n x m grid stored\n",
    "    row-major (channel i*m + j is patch row i, column j). Output channel\n",
    "    i*(m//2) + j is the sum of the channels of patches (2i, 2j), (2i, 2j+1),\n",
    "    (2i+1, 2j) and (2i+1, 2j+1). A trailing odd row/column is ignored.\n",
    "\n",
    "    Returns:\n",
    "        Non-trainable float32 Parameter of shape [(n//2)*(m//2), n*m, 1, 1].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if n < 2 or m < 2 (there is no 2x2 block to sum).\n",
    "    \"\"\"\n",
    "    input_canal = int(n * m)\n",
    "    n_out = int(n / 2)\n",
    "    m_out = int(m / 2)\n",
    "    if n_out < 1 or m_out < 1:\n",
    "        raise ValueError('kernel_add_ini needs n >= 2 and m >= 2, got n=%s, m=%s' % (n, m))\n",
    "    kernel_add = np.zeros([n_out * m_out, input_canal, 1, 1], dtype='float32')\n",
    "    for i in range(n_out):\n",
    "        for j in range(m_out):\n",
    "            top_left = i * 2 * m + j * 2\n",
    "            # Offsets of the four channels of the 2x2 block in row-major order.\n",
    "            for offset in (0, 1, m, m + 1):\n",
    "                kernel_add[i * m_out + j, top_left + offset, 0, 0] = 1\n",
    "    return torch.nn.Parameter(torch.from_numpy(kernel_add), requires_grad=False)\n",
    "\n",
    "\n",
    "def kernel_shift_ini(n, m):\n",
    "    \"\"\"Build the fixed 3x3 kernel that translates each patch channel diagonally.\n",
    "\n",
    "    Channel k = i*m + j (patch row i, column j) gets a 3x3 kernel holding a\n",
    "    single 1 in the corner chosen by the parity of (i, j): top-left for\n",
    "    (even, even), top-right for (even, odd), bottom-left for (odd, even)\n",
    "    and bottom-right for (odd, odd). Channels are never mixed.\n",
    "\n",
    "    Returns:\n",
    "        Non-trainable float32 Parameter of shape [n*m, n*m, 3, 3].\n",
    "    \"\"\"\n",
    "    canal = int(n * m)\n",
    "    kernel_shift = torch.zeros([canal, canal, 3, 3])\n",
    "    for i in range(n):\n",
    "        for j in range(m):\n",
    "            k = i * m + j\n",
    "            kernel_shift[k, k, 2 * (i % 2), 2 * (j % 2)] = 1\n",
    "    return torch.nn.Parameter(kernel_shift, requires_grad=False)\n",
    "\n",
    "\n",
    "def get_patch(fragment, psize, n, m):\n",
    "    \"\"\"Return the psize x psize patch at grid row n, column m of fragment.\n",
    "\n",
    "    fragment has shape [Batch, Channel, Height, Width]. The patch is used as\n",
    "    a convolution kernel to compute one correlation map.\n",
    "    \"\"\"\n",
    "    return fragment[:, :, n * psize:(n + 1) * psize, m * psize:(m + 1) * psize]\n",
    "\n",
    "###################################################################################################################\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Correlate a fragment with an image, patch by patch.\n",
    "\n",
    "    The fragment is split into an n x m grid of 4x4 patches. Each patch\n",
    "    descriptor is correlated with the image descriptor (one map per patch),\n",
    "    then the n*m maps are merged by up to three fixed stages of\n",
    "    max-pooling, diagonal shift and 2x2 summation. The peak of the final\n",
    "    map presumably marks the fragment position (see show_coordonnee).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, frag_size, psize):\n",
    "        \"\"\"Build the layers for a square fragment of side frag_size.\n",
    "\n",
    "        Args:\n",
    "            frag_size: fragment height and width in pixels.\n",
    "            psize: patch side used to size the merge layers.\n",
    "                NOTE(review): forward() always cuts 4x4 patches from the\n",
    "                fragment it receives, so frag_size / psize here must equal\n",
    "                (fragment side / 4) there - normally psize=4. Confirm.\n",
    "        \"\"\"\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        h_fr = frag_size\n",
    "        w_fr = frag_size\n",
    "\n",
    "        n = int(h_fr / psize)  # the fragment holds n*m patches\n",
    "        m = int(w_fr / psize)\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)\n",
    "        self.Relu = nn.ReLU(inplace=True)\n",
    "        self.maxpooling = nn.MaxPool2d(3, stride=2, padding=1)\n",
    "\n",
    "        # The shift/add weights are fixed. NOTE(review): their biases keep the\n",
    "        # default random initialisation and stay trainable; left unchanged for\n",
    "        # pickle/state_dict compatibility - confirm this is intended.\n",
    "        self.shift1 = nn.Conv2d(n * m, n * m, kernel_size=3, stride=1, padding=1)\n",
    "        self.shift1.weight = kernel_shift_ini(n, m)\n",
    "        self.add1 = nn.Conv2d(n * m, int(n / 2) * int(m / 2), kernel_size=1, stride=1, padding=0)\n",
    "        self.add1.weight = kernel_add_ini(n, m)\n",
    "\n",
    "        # Stages 2 and 3 exist only while the halved patch grid is still >= 2x2.\n",
    "        n = int(n / 2)\n",
    "        m = int(m / 2)\n",
    "        if n >= 2 and m >= 2:\n",
    "            self.shift2 = nn.Conv2d(n * m, n * m, kernel_size=3, stride=1, padding=1)\n",
    "            self.shift2.weight = kernel_shift_ini(n, m)\n",
    "            self.add2 = nn.Conv2d(n * m, int(n / 2) * int(m / 2), kernel_size=1, stride=1, padding=0)\n",
    "            self.add2.weight = kernel_add_ini(n, m)\n",
    "\n",
    "        n = int(n / 2)\n",
    "        m = int(m / 2)\n",
    "        if n >= 2 and m >= 2:\n",
    "            self.shift3 = nn.Conv2d(n * m, n * m, kernel_size=3, stride=1, padding=1)\n",
    "            self.shift3.weight = kernel_shift_ini(n, m)\n",
    "            self.add3 = nn.Conv2d(n * m, int(n / 2) * int(m / 2), kernel_size=1, stride=1, padding=0)\n",
    "            self.add3.weight = kernel_add_ini(n, m)\n",
    "\n",
    "    def get_descripteur(self, img, using_cuda):\n",
    "        \"\"\"Compute the 9-channel per-pixel descriptor of img.\n",
    "\n",
    "        The 8 ReLU(conv1) feature channels are concatenated with a constant\n",
    "        0.5 channel; the spatial size is unchanged (3x3 conv, padding 1).\n",
    "\n",
    "        Args:\n",
    "            img: tensor of shape [B, 3, H, W].\n",
    "            using_cuda: kept for backward compatibility only; the constant\n",
    "                channel follows the device, dtype and batch size of the\n",
    "                features, so it works on CPU, on GPU and for B > 1.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape [B, 9, H, W].\n",
    "        \"\"\"\n",
    "        descripteur_img = self.Relu(self.conv1(img))\n",
    "        b, _, h, w = descripteur_img.shape\n",
    "        couche_constante = torch.full((b, 1, h, w), 0.5,\n",
    "                                      dtype=descripteur_img.dtype,\n",
    "                                      device=descripteur_img.device)\n",
    "        # NOTE(review): the descriptor is returned unnormalised (a normalised\n",
    "        # copy used to be computed here but was never used). Confirm that\n",
    "        # per-pixel normalisation is not wanted.\n",
    "        return torch.cat((descripteur_img, couche_constante), 1)\n",
    "\n",
    "    def forward(self, img, frag, using_cuda):\n",
    "        \"\"\"Correlate frag with img and merge the per-patch correlation maps.\n",
    "\n",
    "        Args:\n",
    "            img: image tensor of shape [B, 3, H, W].\n",
    "            frag: fragment tensor of shape [1, 3, h, w] (presumably batch 1,\n",
    "                so that each 4x4 patch yields exactly one correlation map).\n",
    "            using_cuda: forwarded to get_descripteur (kept for compatibility).\n",
    "\n",
    "        Returns:\n",
    "            (c1, c2, map) as numpy arrays: the maps after merge stage 1, after\n",
    "            stage 2 (same values as c1 when stage 2 is skipped), and the final\n",
    "            map divided by its maximum.\n",
    "        \"\"\"\n",
    "        # NOTE(review): the patch size is fixed at 4 here, while __init__ sizes\n",
    "        # its layers from its own psize argument; the two must agree.\n",
    "        psize = 4\n",
    "\n",
    "        descripteur_input1 = self.get_descripteur(img, using_cuda)\n",
    "        descripteur_input2 = self.get_descripteur(frag, using_cuda)\n",
    "\n",
    "        h, w = frag.shape[-2:]\n",
    "        n = int(h / psize)\n",
    "        m = int(w / psize)\n",
    "\n",
    "        # One correlation map per fragment patch (row-major patch order),\n",
    "        # stacked along the channel axis with a single torch.cat.\n",
    "        maps = [F.conv2d(descripteur_input1,\n",
    "                         get_patch(descripteur_input2, psize, i, j),\n",
    "                         padding=2)\n",
    "                for i in range(n) for j in range(m)]\n",
    "        map_corre = torch.cat(maps, 1)\n",
    "\n",
    "        # Merge stage 1 (always present).\n",
    "        map_corre = self.maxpooling(map_corre)\n",
    "        map_corre = self.shift1(map_corre)\n",
    "        map_corre = self.add1(map_corre)\n",
    "        # detach().cpu() so the conversion also works when running on the GPU.\n",
    "        c1 = map_corre.detach().cpu().numpy()\n",
    "\n",
    "        # Merge stages 2 and 3 run only while the halved grid is still >= 2x2.\n",
    "        n = int(n / 2)\n",
    "        m = int(m / 2)\n",
    "        if n >= 2 and m >= 2:\n",
    "            map_corre = self.maxpooling(map_corre)\n",
    "            map_corre = self.shift2(map_corre)\n",
    "            map_corre = self.add2(map_corre)\n",
    "        c2 = map_corre.detach().cpu().numpy()\n",
    "\n",
    "        n = int(n / 2)\n",
    "        m = int(m / 2)\n",
    "        if n >= 2 and m >= 2:\n",
    "            map_corre = self.maxpooling(map_corre)\n",
    "            map_corre = self.shift3(map_corre)\n",
    "            map_corre = self.add3(map_corre)\n",
    "\n",
    "        map_corre = map_corre / map_corre.max()\n",
    "        return c1, c2, map_corre.detach().cpu().numpy()"
   ]
  },
  { "cell_type": "code", "execution_count": 28,