{ "cells": [ { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import matplotlib\n", "%matplotlib notebook\n", "from matplotlib import pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "from matplotlib import cm\n", "import pickle\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import cv2\n", "from random import randint" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "#Les fonctions dans ce bloc ne sont pas utilisées par le réseau, mais certaines fonctions d'outils\n", "\n", "# Cette fonction convertit l'image en variable de type Tensor.\n", "# Toutes les données de calcul du réseau sont de type Tensor\n", "# Img.shape=[Height,Width,Channel]\n", "# Tensor.shape=[Batch,Channel,Height,Width]\n", "def img2tensor(im):\n", " im=np.array(im,dtype=\"float32\")\n", " tensor_cv = torch.from_numpy(np.transpose(im, (2, 0, 1)))\n", " im_tensor=tensor_cv.unsqueeze(0)\n", " return im_tensor\n", "\n", "# Trouvez les coordonnées de la valeur maximale dans une carte de corrélation\n", "# x,y=show_coordonnee(carte de corrélation)\n", "def show_coordonnee(position_pred):\n", " map_corre=position_pred.squeeze().detach().cpu().numpy()\n", " h,w=map_corre.shape\n", " max_value=map_corre.max()\n", " coordonnee=np.where(map_corre==max_value)\n", " return coordonnee[0].mean(),coordonnee[1].mean()\n", " \n", "def load_net(file_path): \n", " pkl_file = open(file_path, 'rb')\n", " net= pickle.load(pkl_file)\n", " pkl_file.close()\n", " return net\n", "\n", "def kernel_add_ini(n,m):\n", " input_canal=int(n*m)\n", " output_canal=int(n/2)*int(m/2)\n", " for i in range(int(n/2)):\n", " for j in range(int(m/2)):\n", " kernel_add=np.zeros([1,input_canal],dtype='float32')\n", " kernel_add[0,i*2*m+j*2]=1\n", " kernel_add[0,i*2*m+j*2+1]=1\n", " kernel_add[0,(i*2+1)*m+j*2]=1\n", " kernel_add[0,(i*2+1)*m+j*2+1]=1\n", " if i==0 and j==0:\n", " add=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n", " else:\n", " add_=torch.from_numpy(kernel_add.reshape(1,input_canal,1,1))\n", " add=torch.cat((add,add_),0)\n", " return torch.nn.Parameter(add,requires_grad=False) \n", "\n", "def kernel_shift_ini(n,m):\n", " input_canal=int(n*m)\n", " output_canal=int(n*m)\n", " \n", " kernel_shift=torch.zeros([output_canal,input_canal,3,3])\n", " \n", " array_0=np.array([[1,0,0],[0,0,0],[0,0,0]],dtype='float32')\n", " array_1=np.array([[0,0,1],[0,0,0],[0,0,0]],dtype='float32')\n", " array_2=np.array([[0,0,0],[0,0,0],[1,0,0]],dtype='float32')\n", " array_3=np.array([[0,0,0],[0,0,0],[0,0,1]],dtype='float32')\n", " \n", " kernel_shift_0=torch.from_numpy(array_0)\n", " kernel_shift_1=torch.from_numpy(array_1)\n", " kernel_shift_2=torch.from_numpy(array_2)\n", " kernel_shift_3=torch.from_numpy(array_3)\n", " \n", " \n", " for i in range(n):\n", " for j in range(m):\n", " if i==0 and j==0:\n", " kernel_shift[0,0,:]=kernel_shift_0\n", " else:\n", " if i%2==0 and j%2==0:\n", " kernel_shift[i*m+j,i*m+j,:]=kernel_shift_0\n", " if i%2==0 and j%2==1:\n", " kernel_shift[i*m+j,i*m+j,:]=kernel_shift_1\n", " if i%2==1 and j%2==0:\n", " kernel_shift[i*m+j,i*m+j,:]=kernel_shift_2\n", " if i%2==1 and j%2==1:\n", " kernel_shift[i*m+j,i*m+j,:]=kernel_shift_3\n", " \n", " return torch.nn.Parameter(kernel_shift,requires_grad=False) \n", "\n", "# Trouvez le petit patch(4 * 4) dans la n ème ligne et la m ème colonne du patch d'entrée\n", "# Ceci est utilisé pour calculer la convolution et obtenir la carte de corrélation\n", "def get_patch(fragment,psize,n,m):\n", " return fragment[:,:,n*psize:(n+1)*psize,m*psize:(m+1)*psize]\n", "###################################################################################################################\n", "\n", "class Net(nn.Module):\n", " def __init__(self,frag_size,psize):\n", " super(Net, self).__init__()\n", " \n", " h_fr=frag_size\n", " w_fr=frag_size\n", " \n", " n=int(h_fr/psize) #n*m patches\n", " m=int(w_fr/psize)\n", " \n", " self.conv1 = nn.Conv2d(3,8,kernel_size=3,stride=1,padding=1)\n", " #self.conv1.weight=ini()\n", " self.Relu = nn.ReLU(inplace=True)\n", " self.maxpooling=nn.MaxPool2d(3,stride=2, padding=1)\n", " \n", " self.shift1=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n", " self.shift1.weight=kernel_shift_ini(n,m)\n", " self.add1 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n", " self.add1.weight=kernel_add_ini(n,m)\n", " \n", " n=int(n/2)\n", " m=int(m/2)\n", " if n>=2 and m>=2:\n", " self.shift2=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n", " self.shift2.weight=kernel_shift_ini(n,m)\n", " self.add2 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n", " self.add2.weight=kernel_add_ini(n,m)\n", " \n", " n=int(n/2)\n", " m=int(m/2)\n", " if n>=2 and m>=2:\n", " self.shift3=nn.Conv2d(n*m,n*m,kernel_size=3,stride=1,padding=1)\n", " self.shift3.weight=kernel_shift_ini(n,m)\n", " self.add3 = nn.Conv2d(n*m,int(n/2)*int(m/2),kernel_size=1,stride=1,padding=0)\n", " self.add3.weight=kernel_add_ini(n,m)\n", " \n", " \n", " def get_descripteur(self,img,using_cuda):\n", " descripteur_img=self.Relu(self.conv1(img))\n", " b,c,h,w=descripteur_img.shape\n", " couche_constante=0.5*torch.ones([1,1,h,w])\n", " if using_cuda:\n", " couche_constante=couche_constante.cuda()\n", " descripteur_img=torch.cat((descripteur_img,couche_constante),1)\n", " descripteur_img_norm=descripteur_img/torch.norm(descripteur_img,dim=1)\n", " return descripteur_img\n", " \n", " def forward(self,img,frag,using_cuda):\n", " psize=4\n", " \n", " descripteur_input1=self.get_descripteur(img,using_cuda)\n", " descripteur_input2=self.get_descripteur(frag,using_cuda)\n", " \n", " b,c,h,w=frag.shape\n", " n=int(h/psize)\n", " m=int(w/psize)\n", " \n", " for i in range(n):\n", " for j in range(m):\n", " if i==0 and j==0:\n", " map_corre=F.conv2d(descripteur_input1,get_patch(descripteur_input2,psize,i,j),padding=2)\n", " else:\n", " a=F.conv2d(descripteur_input1,get_patch(descripteur_input2,psize,i,j),padding=2)\n", " map_corre=torch.cat((map_corre,a),1)\n", " #shift\n", " map_corre=self.maxpooling(map_corre)\n", " map_corre=self.shift1(map_corre)\n", " map_corre=self.add1(map_corre)\n", " c1 = map_corre.data.numpy()\n", " \n", " \n", " n=int(n/2)\n", " m=int(m/2)\n", " if n>=2 and m>=2:\n", " map_corre=self.maxpooling(map_corre)\n", " map_corre=self.shift2(map_corre)\n", " map_corre=self.add2(map_corre)\n", " \n", " c2 = map_corre.data.numpy()\n", " \n", " \n", " n=int(n/2)\n", " m=int(m/2)\n", " if n>=2 and m>=2:\n", " map_corre=self.maxpooling(map_corre)\n", " map_corre=self.shift3(map_corre)\n", " map_corre=self.add3(map_corre)\n", " \n", " \n", " b,c,h,w=map_corre.shape\n", " map_corre=map_corre/(map_corre.max())\n", " #map_corre=(F.softmax(map_corre.reshape(1,1,h*w,1),dim=2)).reshape(b,c,h,w)\n", " return c1,c2,map_corre.data.numpy()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "cartes_filename = './trained_net/cartes_correlations_NOMB_02-15_01-37_0043'\n", "\n", "with open(cartes_filename,'rb') as pkl_file:\n", " cartes = pickle.load(pkl_file)\n", "\n", "cartes1 = cartes['carte1']\n", "cartes2 = cartes['carte2']\n", "cartes3 = cartes['carte3']\n", "\n", "N = 20\n", "c1 = cartes1[N][0,0,:,:]\n", "c2 = cartes2[N][0,0,:,:]\n", "c3 = cartes3[N][0,0,:,:]\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fresque 1, fragment 2956\n" ] }, { "ename": "FileNotFoundError", "evalue": "[Errno 2] No such file or directory: './trained_net/net_trainned_NOMB_E3_02-23_07-23_0007'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mfresque_tensor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mimg2tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfresque\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnet_filename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mload_net\u001b[0;34m(file_path)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload_net\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mpkl_file\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpkl_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mpkl_file\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './trained_net/net_trainned_NOMB_E3_02-23_07-23_0007'" ] } ], "source": [ "base_dir = './training_data_32/'\n", "fresque_filename = base_dir+'fresque_small{}.ppm'\n", "fresque_filename_wild = base_dir+'fresque_small*.ppm'\n", "fragment_filename = base_dir+'fragments/fresque{}/frag_dev_{:05}.ppm'\n", "fragments_filename_wild = base_dir+'fragments/fresque{}/frag_bench_*.ppm'\n", "vt_filename = base_dir+'fragments/fresque{}/vt/frag_bench_{:05}_vt.txt'\n", "net_filename = \"./trained_net/net_trainned_SLLShift_E3_03-10_21-02_0007\"\n", "fresque_id = randint(0,5)\n", "fragment_id = randint(0,3000)\n", "frag_size=32\n", "\n", "print(\"Fresque {}, fragment {}\".format(fresque_id,fragment_id))\n", "\n", "frag = cv2.imread(fragment_filename.format(fresque_id,fragment_id))\n", "frag_tensor=img2tensor(frag)\n", "\n", "fresque = cv2.imread(fresque_filename.format(fresque_id))\n", "fresque_tensor=img2tensor(fresque)\n", "\n", "net = load_net(net_filename)\n", "net.cpu()\n", "\n", "carte1,carte2,carte3=net.forward(fresque_tensor,frag_tensor,using_cuda=False)\n", "print(\"Done.\")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", " nav_element.append(button);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = $('');\n", " nav_element.append(status_bar);\n", " this.message = status_bar[0];\n", "\n", " // Add the close button to the window.\n", " var buttongrp = $('
');\n", " var button = $('');\n", " button.click(function (evt) { fig.handle_close(fig, {}); } );\n", " button.mouseover('Stop Interaction', toolbar_mouse_event);\n", " buttongrp.append(button);\n", " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n", " titlebar.prepend(buttongrp);\n", "}\n", "\n", "mpl.figure.prototype._root_extra_style = function(el){\n", " var fig = this\n", " el.on(\"remove\", function(){\n", "\tfig.close_ws(fig, {});\n", " });\n", "}\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(el){\n", " // this is important to make the div 'focusable\n", " el.attr('tabindex', 0)\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " }\n", " else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "\n", "}\n", "\n", "mpl.figure.prototype._key_event_extra = function(event, name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager)\n", " manager = IPython.keyboard_manager;\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which == 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "}\n", "\n", "mpl.figure.prototype.handle_save = function(fig, msg) {\n", " fig.ondownload(fig, null);\n", "}\n", "\n", "\n", "mpl.find_output_cell = function(html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i=0; i= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] == html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel != null) {\n", " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_filename = \"./trained_net/net_trainned_SLLShift_E3_03-10_21-02_0007\"\n", "fresque_filename = \"./training_data_32/fresque5.ppm\"\n", "net = load_net(model_filename)\n", "net = net.cpu()\n", "fresque = cv2.imread(fresque_filename)\n", "fresque_tensor = img2tensor(fresque)\n", "\n", "#descripteur = net.get_descripteur(fresque,False).detach().numpy()\n", "#descripteur_img=self.Relu(self.conv1(img))\n", "desc = net.Relu(net.conv1(fresque_tensor)).detach().numpy()\n", "\n", "fig = plt.figure(figsize=(13,3))\n", "\n", "fs = 15\n", "ax = fig.add_subplot(1,5,1)\n", "ax.set_title(\"Fresque Originale\",fontsize = fs)\n", "ax.set_yticks([])\n", "ax.set_xticks([])\n", "ax.imshow(fresque)\n", "\n", "for i in range(4):\n", " ax = fig.add_subplot(1,5,i+2)\n", " ax.set_title(\"Couche {}\".format(i+1),fontsize = fs)\n", " ax.set_yticks([])\n", " ax.set_xticks([])\n", " ax = plt.imshow(desc[0,i,:,:],cmap=cm.coolwarm)\n", " #axs[i].set_title(\"Couche {}\".format(i))\n", "\n", "fig.savefig(\"features_4_shift.png\",figsize=(20,8))\n", "fig.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Affichage pour 8 Couche" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '
');\n", " var titletext = $(\n", " '
');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('
');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('
');\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('