corr 2
This commit is contained in:
parent
0ed986773b
commit
ac5f8433e0
2 changed files with 65 additions and 11 deletions
|
@ -2,7 +2,7 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -255,15 +255,6 @@
|
|||
"source": [
|
||||
"print(wv.doesnt_match(['wood', 'oak', 'tree', 'iron', 'leaf']))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(wv.most_similar(positive=['The largest country is']))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
65
wordvec.py
65
wordvec.py
|
@ -1,4 +1,9 @@
|
|||
import gensim.downloader as api
|
||||
from sklearn.decomposition import IncrementalPCA # inital reduction
|
||||
from sklearn.manifold import TSNE # final reduction
|
||||
import numpy as np # array handling
|
||||
|
||||
|
||||
wv = api.load('word2vec-google-news-300')
|
||||
|
||||
# Affichage de quelques mots du vocabulaire
|
||||
|
@ -17,4 +22,62 @@ vec_woman = wv['woman']
|
|||
result = wv.most_similar(positive=(vec_father - vec_man + vec_woman), topn=1)
|
||||
print(result)
|
||||
|
||||
print(wv.most_similar(positive=['The largest country is']))
|
||||
## Visualisation
|
||||
|
||||
def reduce_dimensions(model):
|
||||
num_dimensions = 2 # final num dimensions (2D, 3D, etc)
|
||||
|
||||
# extract the words & their vectors, as numpy arrays
|
||||
vectors = np.asarray(model.wv.vectors)
|
||||
labels = np.asarray(model.wv.index_to_key) # fixed-width numpy strings
|
||||
|
||||
# reduce using t-SNE
|
||||
tsne = TSNE(n_components=num_dimensions, random_state=0)
|
||||
vectors = tsne.fit_transform(vectors)
|
||||
|
||||
x_vals = [v[0] for v in vectors]
|
||||
y_vals = [v[1] for v in vectors]
|
||||
return x_vals, y_vals, labels
|
||||
|
||||
|
||||
x_vals, y_vals, labels = reduce_dimensions(model)
|
||||
|
||||
def plot_with_plotly(x_vals, y_vals, labels, plot_in_notebook=True):
|
||||
from plotly.offline import init_notebook_mode, iplot, plot
|
||||
import plotly.graph_objs as go
|
||||
|
||||
trace = go.Scatter(x=x_vals, y=y_vals, mode='text', text=labels)
|
||||
data = [trace]
|
||||
|
||||
if plot_in_notebook:
|
||||
init_notebook_mode(connected=True)
|
||||
iplot(data, filename='word-embedding-plot')
|
||||
else:
|
||||
plot(data, filename='word-embedding-plot.html')
|
||||
|
||||
|
||||
def plot_with_matplotlib(x_vals, y_vals, labels):
|
||||
import matplotlib.pyplot as plt
|
||||
import random
|
||||
|
||||
random.seed(0)
|
||||
|
||||
plt.figure(figsize=(12, 12))
|
||||
plt.scatter(x_vals, y_vals)
|
||||
|
||||
#
|
||||
# Label randomly subsampled 25 data points
|
||||
#
|
||||
indices = list(range(len(labels)))
|
||||
selected_indices = random.sample(indices, 25)
|
||||
for i in selected_indices:
|
||||
plt.annotate(labels[i], (x_vals[i], y_vals[i]))
|
||||
|
||||
try:
|
||||
get_ipython()
|
||||
except Exception:
|
||||
plot_function = plot_with_matplotlib
|
||||
else:
|
||||
plot_function = plot_with_plotly
|
||||
|
||||
plot_function(x_vals, y_vals, labels)
|
||||
|
|
Loading…
Reference in a new issue