{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F  # used by Net.forward (F.conv2d)\n",
    "\n",
    "import matplotlib\n",
    "%matplotlib notebook\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_net(file_path):\n",
    "    \"\"\"Unpickle and return the object stored at ``file_path`` (presumably a trained ``Net``).\n",
    "\n",
    "    WARNING: ``pickle.load`` can execute arbitrary code -- only load trusted files.\n",
    "    If the pickle holds a ``Net``, the ``Net`` class below must be defined first.\n",
    "    \"\"\"\n",
    "    # ``with`` closes the file even if unpickling raises.\n",
    "    with open(file_path, 'rb') as pkl_file:\n",
    "        return pickle.load(pkl_file)\n",
    "\n",
    "\n",
    "def ini():\n",
    "    \"\"\"Return 8 fixed 3x3 gradient kernels as a trainable conv weight of shape (8, 3, 3, 3).\n",
    "\n",
    "    Kernel ``k`` is copied onto each of the 3 input channels; kernels 4-7 are the\n",
    "    negations of kernels 0-3.\n",
    "    NOTE(review): ``Net.conv1`` has 4 output channels, so this 8-filter weight does\n",
    "    not match it as written -- confirm before assigning ``ini()`` to it.\n",
    "    \"\"\"\n",
    "    base_kernels = np.array([\n",
    "        [[1, 2, 1], [0, 0, 0], [-1, -2, -1]],\n",
    "        [[2, 1, 0], [1, 0, -1], [0, -1, -2]],\n",
    "        [[1, 0, -1], [2, 0, -2], [1, 0, -1]],\n",
    "        [[0, -1, -2], [1, 0, -1], [2, 1, 0]],\n",
    "        [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],\n",
    "        [[-2, -1, 0], [-1, 0, 1], [0, 1, 2]],\n",
    "        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],\n",
    "        [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]],\n",
    "    ], dtype='float32')  # (8, 3, 3)\n",
    "    # Repeat each 3x3 kernel over the 3 input channels -> (8, 3, 3, 3).\n",
    "    kernel = torch.from_numpy(base_kernels).unsqueeze(1).repeat(1, 3, 1, 1)\n",
    "    return torch.nn.Parameter(kernel, requires_grad=True)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Patch-correlation network: conv1 descriptors + hierarchical aggregation.\n",
    "\n",
    "    A square fragment of side ``frag_size`` is split into an n x m grid of\n",
    "    ``psize`` x ``psize`` patches (n = m = int(frag_size / psize)). Each aggregation\n",
    "    stage (max-pool, 3x3 ``shift`` conv, 1x1 ``add`` conv) maps n*m correlation\n",
    "    channels to (n // 2) * (m // 2); up to three stages are built while the grid\n",
    "    is still at least 2x2.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, frag_size, psize):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        # n x m patches inside the (square) input fragment.\n",
    "        n = int(frag_size / psize)\n",
    "        m = int(frag_size / psize)\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=1)\n",
    "        # To initialise conv1 with the DeepMatching kernels, uncomment the next line\n",
    "        # (see the shape caveat in ``ini``):\n",
    "        # self.conv1.weight = ini()\n",
    "        self.Relu = nn.ReLU(inplace=True)\n",
    "        self.maxpooling = nn.MaxPool2d(3, stride=2, padding=1)\n",
    "\n",
    "        self.shift1, self.add1 = self._make_stage(n, m)\n",
    "\n",
    "        # Once the grid is down to 1x1, no further aggregation layers are needed.\n",
    "        n, m = n // 2, m // 2\n",
    "        if n >= 2 and m >= 2:\n",
    "            self.shift2, self.add2 = self._make_stage(n, m)\n",
    "\n",
    "        n, m = n // 2, m // 2\n",
    "        if n >= 2 and m >= 2:\n",
    "            self.shift3, self.add3 = self._make_stage(n, m)\n",
    "\n",
    "    @staticmethod\n",
    "    def _make_stage(n, m):\n",
    "        \"\"\"Build the (shift, add) conv pair for one aggregation stage on an n x m grid.\n",
    "\n",
    "        NOTE(review): ``kernel_shift_ini`` and ``kernel_add_ini`` are not defined in\n",
    "        this cell; they must be in scope before a ``Net`` is constructed.\n",
    "        \"\"\"\n",
    "        shift = nn.Conv2d(n * m, n * m, kernel_size=3, stride=1, padding=1)\n",
    "        shift.weight = kernel_shift_ini(n, m)\n",
    "        add = nn.Conv2d(n * m, (n // 2) * (m // 2), kernel_size=1, stride=1, padding=0)\n",
    "        add.weight = kernel_add_ini(n, m)\n",
    "        return shift, add\n",
    "\n",
    "    def _aggregate(self, map_corre, shift, add):\n",
    "        \"\"\"Run one aggregation stage: 3x3/stride-2 max-pool, then ``shift``, then ``add``.\"\"\"\n",
    "        return add(shift(self.maxpooling(map_corre)))\n",
    "\n",
    "    def get_descripteur(self, img, using_cuda):\n",
    "        \"\"\"Return conv1+ReLU descriptors of ``img``, L2-normalised along the channel axis.\n",
    "\n",
    "        A constant 0.5 channel is appended before normalising so that no pixel has a\n",
    "        zero norm (the ReLU outputs may all be 0). ``using_cuda`` is kept for backward\n",
    "        compatibility: the constant channel is now created on the descriptor's own\n",
    "        device and dtype, which also makes batch sizes > 1 work.\n",
    "        \"\"\"\n",
    "        descripteur_img = self.Relu(self.conv1(img))\n",
    "        b, _, h, w = descripteur_img.shape\n",
    "        # One constant channel per batch element (was hard-coded to batch size 1).\n",
    "        couche_constante = descripteur_img.new_full((b, 1, h, w), 0.5)\n",
    "        descripteur_img = torch.cat((descripteur_img, couche_constante), 1)\n",
    "        # keepdim=True broadcasts the per-pixel norm over channels for any batch size.\n",
    "        return descripteur_img / torch.norm(descripteur_img, dim=1, keepdim=True)\n",
    "\n",
    "    def forward(self, img, frag, using_cuda):\n",
    "        \"\"\"Correlate fragment ``frag`` against image ``img``; return the map scaled to max 1.\n",
    "\n",
    "        ``using_cuda`` is forwarded to ``get_descripteur``.\n",
    "        NOTE(review): the patch size is hard-coded to 4 here and must match the\n",
    "        ``psize`` given to ``__init__``; ``get_patch`` is not defined in this cell.\n",
    "        \"\"\"\n",
    "        psize = 4\n",
    "        descripteur_input1 = self.get_descripteur(img, using_cuda)\n",
    "        descripteur_input2 = self.get_descripteur(frag, using_cuda)\n",
    "\n",
    "        h, w = frag.shape[-2:]\n",
    "        n = int(h / psize)\n",
    "        m = int(w / psize)\n",
    "\n",
    "        # Correlate the image descriptors with each of the n*m fragment patches by\n",
    "        # convolution; results are stacked on the channel axis in row-major (i, j) order.\n",
    "        correlations = [\n",
    "            F.conv2d(descripteur_input1, get_patch(descripteur_input2, psize, i, j), padding=2)\n",
    "            for i in range(n)\n",
    "            for j in range(m)\n",
    "        ]\n",
    "        map_corre = torch.cat(correlations, 1)\n",
    "\n",
    "        # Aggregate the patch maps stage by stage until the whole-fragment map remains.\n",
    "        map_corre = self._aggregate(map_corre, self.shift1, self.add1)\n",
    "        n, m = n // 2, m // 2\n",
    "        if n >= 2 and m >= 2:\n",
    "            map_corre = self._aggregate(map_corre, self.shift2, self.add2)\n",
    "        n, m = n // 2, m // 2\n",
    "        if n >= 2 and m >= 2:\n",
    "            map_corre = self._aggregate(map_corre, self.shift3, self.add3)\n",
    "\n",
    "        # Normalise by the global maximum (a softmax over positions is an alternative).\n",
    "        return map_corre / map_corre.max()\n",
    "\n",
    "\n",
    "def normalize(a):\n",
    "    \"\"\"Min-max scale ``a`` to [0, 1]; a constant array maps to all zeros instead of NaN.\"\"\"\n",
    "    a = np.asarray(a)\n",
    "    spread = np.ptp(a)\n",
    "    if spread == 0:\n",
    "        return np.zeros(a.shape)\n",
    "    return (a - np.min(a)) / spread\n",
    "\n",
    "\n",
    "def _plot_kernel_grid(w, n_filters, save_filename, title):\n",
    "    \"\"\"Show the first ``n_filters`` filters x 3 input channels of conv weights ``w``.\n",
    "\n",
    "    ``w`` is indexed as ``w[filter, channel, :, :]``; each panel is min-max normalised\n",
    "    to [0, 1] and drawn with fixed colour limits so the shared colour bar applies to\n",
    "    every panel. The figure is saved to ``save_filename`` unless it is None.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(3, n_filters, figsize=(15, 8), squeeze=False)\n",
    "\n",
    "    for channel in range(3):\n",
    "        for filt in range(n_filters):\n",
    "            ax = axs[channel, filt]\n",
    "            im = ax.imshow(normalize(w[filt, channel, :, :]), vmin=0, vmax=1)\n",
    "            if channel == 0:\n",
    "                ax.set_title('Couche {}'.format(filt))\n",
    "            if filt == 0:\n",
    "                ax.set_ylabel('Channel {}'.format(channel + 1))\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "\n",
    "    fig.subplots_adjust(right=0.8)\n",
    "    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
    "    fig.colorbar(im, cax=cbar_ax)\n",
    "    fig.suptitle('{}'.format(title), fontsize=16)\n",
    "\n",
    "    if save_filename is not None:\n",
    "        fig.savefig(save_filename)\n",
    "\n",
    "\n",
    "def carte(w, save_filename, title):\n",
    "    \"\"\"Plot the first 8 filters of conv weights ``w``, each panel min-max normalised.\"\"\"\n",
    "    _plot_kernel_grid(w, 8, save_filename, title)\n",
    "\n",
    "\n",
    "def carte4(w, save_filename, title):\n",
    "    \"\"\"Plot the first 4 filters of conv weights ``w``, each panel min-max normalised.\"\"\"\n",
    "    _plot_kernel_grid(w, 4, save_filename, title)\n",
    "\n",
    "\n",
    "def carte_32(w, save_filename, title):\n",
    "    \"\"\"Plot the raw values of the first 4 filters x 3 channels of ``w`` on one shared scale.\n",
    "\n",
    "    Panels use the ``coolwarm`` colormap with a single ``Normalize(min(w), max(w))`` so\n",
    "    the colour bar is valid for every panel; each kernel cell is annotated with its\n",
    "    value rounded to 2 decimals. Saved to ``save_filename`` unless it is None.\n",
    "    \"\"\"\n",
    "    cmap = matplotlib.cm.coolwarm\n",
    "    norm = matplotlib.colors.Normalize(vmin=np.min(w), vmax=np.max(w))\n",
    "\n",
    "    fig, axs = plt.subplots(3, 4, figsize=(15, 8))\n",
    "    for channel in range(3):\n",
    "        for filt in range(4):\n",
    "            ax = axs[channel, filt]\n",
    "            kernel = w[filt, channel, :, :]\n",
    "            # Shared norm: without it each panel auto-scales and the colour bar lies.\n",
    "            ax.imshow(kernel, cmap=cmap, norm=norm)\n",
    "            for row in range(kernel.shape[0]):\n",
    "                for col in range(kernel.shape[1]):\n",
    "                    ax.text(col, row, round(kernel[row, col], 2),\n",
    "                            ha='center', va='center', color='w')\n",
    "            if channel == 0:\n",
    "                ax.set_title('Couche {}'.format(filt))\n",
    "            if filt == 0:\n",
    "                ax.set_ylabel('Channel {}'.format(channel + 1))\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "\n",
    "    fig.subplots_adjust(right=0.8)\n",
    "    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
    "    fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax, orientation='vertical')\n",
    "    fig.suptitle('{}'.format(title), fontsize=16)\n",
    "\n",
    "    if save_filename is not None:\n",
    "        fig.savefig(save_filename)"
   ]
  },
  { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(8, 3, 3, 3)\n" ] }, { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " 
rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('