This commit is contained in:
higepi 2022-10-06 17:47:26 +02:00
parent 0ed986773b
commit ac5f8433e0
2 changed files with 65 additions and 11 deletions

View file

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -255,15 +255,6 @@
"source": [
"print(wv.doesnt_match(['wood', 'oak', 'tree', 'iron', 'leaf']))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(wv.most_similar(positive=['The largest country is']))"
]
}
],
"metadata": {

View file

@ -1,4 +1,9 @@
import gensim.downloader as api
from sklearn.decomposition import IncrementalPCA # initial reduction
from sklearn.manifold import TSNE # final reduction
import numpy as np # array handling
# Load the 'word2vec-google-news-300' resource through the gensim downloader
# (imported above as `api`) and bind the result to `wv`; every query further
# down in this script goes through `wv`.
# NOTE(review): presumably this returns a KeyedVectors instance and downloads
# the data on first use -- confirm against the gensim downloader docs.
wv = api.load('word2vec-google-news-300')
# Display a few words from the vocabulary
@ -17,4 +22,62 @@ vec_woman = wv['woman']
# Analogy query: find the single entry closest to the composed vector
# father - man + woman.
result = wv.most_similar(positive=(vec_father - vec_man + vec_woman), topn=1)
print(result)

# Free-text query. The whole phrase is looked up as ONE vocabulary key, and
# most_similar() raises KeyError for a key that is not in the vocabulary.
# Left unguarded, that error would abort the script before the visualisation
# section below ever runs, so report it and carry on instead.
phrase = 'The largest country is'
try:
    phrase_neighbours = wv.most_similar(positive=[phrase])
except KeyError as err:
    print(f"No vector for {phrase!r}: {err}")
else:
    print(phrase_neighbours)

## Visualisation
def reduce_dimensions(model, max_words=None):
    """Project word vectors onto two dimensions with t-SNE.

    Args:
        model: Either an object that exposes its vectors through a ``wv``
            attribute, or an object that itself provides ``vectors`` and
            ``index_to_key`` (such as the ``wv`` loaded by this script).
        max_words: Optional cap on how many entries are embedded; the first
            ``max_words`` entries of ``index_to_key`` (and the matching rows
            of ``vectors``) are kept. ``None`` (the default) embeds every
            vector, which can be extremely slow for a large vocabulary.

    Returns:
        A ``(x_vals, y_vals, labels)`` tuple: two lists holding the first and
        second t-SNE coordinate of every point, and a numpy array of the
        corresponding labels, all in the same order.
    """
    num_dimensions = 2  # final num dimensions (2D, 3D, etc)

    # Use `model.wv` when it exists; otherwise treat `model` as the vector
    # store itself, so both kinds of object can be passed in.
    keyed_vectors = getattr(model, 'wv', model)

    # extract the words & their vectors, as numpy arrays
    # (slicing with None keeps everything, so the default is unchanged)
    vectors = np.asarray(keyed_vectors.vectors)[:max_words]
    labels = np.asarray(keyed_vectors.index_to_key[:max_words])  # fixed-width numpy strings

    # reduce using t-SNE; the fixed random_state keeps runs reproducible
    tsne = TSNE(n_components=num_dimensions, random_state=0)
    embedded = tsne.fit_transform(vectors)

    # Column slices replace the per-row Python loop; list() yields the same
    # numpy scalars the row-by-row extraction produced.
    x_vals = list(embedded[:, 0])
    y_vals = list(embedded[:, 1])
    return x_vals, y_vals, labels


# `wv` is the vector store loaded at the top of this script; no `model` name
# is defined in the visible part of the file.
# NOTE(review): this embeds the entire vocabulary -- consider passing
# `max_words=...` if the run time is impractical.
x_vals, y_vals, labels = reduce_dimensions(wv)
def plot_with_plotly(x_vals, y_vals, labels, plot_in_notebook=True):
    """Draw the labelled points with plotly's offline mode.

    Args:
        x_vals: X coordinates of the points.
        y_vals: Y coordinates of the points.
        labels: Text drawn at each point (the trace uses ``mode='text'``).
        plot_in_notebook: When true, initialise notebook mode and render with
            ``iplot``; otherwise hand the figure to ``plot`` with the
            filename 'word-embedding-plot.html'.
    """
    from plotly.offline import init_notebook_mode, iplot, plot
    import plotly.graph_objs as go

    traces = [go.Scatter(x=x_vals, y=y_vals, mode='text', text=labels)]

    # Standalone case first: render through plot() and stop.
    if not plot_in_notebook:
        plot(traces, filename='word-embedding-plot.html')
        return

    # Notebook case: notebook mode must be initialised before iplot().
    init_notebook_mode(connected=True)
    iplot(traces, filename='word-embedding-plot')
def plot_with_matplotlib(x_vals, y_vals, labels):
    """Scatter-plot the points with matplotlib and label a random subset.

    Args:
        x_vals: X coordinates of the points.
        y_vals: Y coordinates of the points.
        labels: Label for each point, in the same order as the coordinates.

    At most 25 points are annotated; the subset is chosen by a generator
    seeded with 0, so the same points are labelled on every run.
    """
    import matplotlib.pyplot as plt
    import random

    # Dedicated seeded generator: same picks as seeding the module-level
    # generator with 0, but without resetting the global `random` state that
    # the rest of the program shares.
    rng = random.Random(0)

    plt.figure(figsize=(12, 12))
    plt.scatter(x_vals, y_vals)

    #
    # Label randomly subsampled 25 data points
    #
    # sample() raises ValueError when asked for more items than exist, so
    # cap the request for inputs with fewer than 25 labels.
    sample_size = min(25, len(labels))
    for i in rng.sample(range(len(labels)), sample_size):
        plt.annotate(labels[i], (x_vals[i], y_vals[i]))

    # This backend is the one selected outside IPython, where a figure is
    # not displayed until show() is called.
    plt.show()
# Choose the plotting backend. Calling `get_ipython` when that name is not
# defined raises NameError, which is taken to mean "not running under
# IPython"; only that error is caught so any other failure still surfaces.
try:
    get_ipython()
except NameError:
    plot_function = plot_with_matplotlib
else:
    plot_function = plot_with_plotly

plot_function(x_vals, y_vals, labels)