Practical example: image classification with CNNs
Goal:
An introduction to convolutional neural networks with TensorFlow/Keras, using an image classification problem.
1. Load libraries
[1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
2. Load data
[2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
3. Inspect data
[3]:
x_train
[3]:
array([[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],
[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],
[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],
...,
[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],
[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]],
[[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]], shape=(60000, 28, 28), dtype=uint8)
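Printing the whole array mostly shows the zero-valued border pixels. A more compact way to inspect the data is to summarize the shape, data type, value range and classes. The sketch below uses a hypothetical helper, summarize, and a small random stand-in array so that it runs without downloading MNIST. Both are assumptions. With the real data you would pass x_train and y_train from above.

```python
import numpy as np


def summarize(images, labels):
    """Hypothetical helper: compact overview of an image dataset instead of the raw array."""
    return {
        "shape": images.shape,                          # (number of images, height, width)
        "dtype": str(images.dtype),                     # uint8 for raw MNIST pixels
        "min": int(images.min()),                       # smallest pixel value
        "max": int(images.max()),                       # largest pixel value
        "classes": sorted(np.unique(labels).tolist()),  # distinct labels
    }


# Hypothetical stand-in for the MNIST arrays: 100 random 28x28 uint8 images.
rng = np.random.default_rng(0)
images = rng.integers(0, 255, size=(100, 28, 28), dtype=np.uint8, endpoint=True)
labels = rng.integers(0, 9, size=100, dtype=np.uint8, endpoint=True)

print(summarize(images, labels))
```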
4. Prepare data (data preprocessing)
Normalization
Image data is normalized somewhat differently from other numerical data. Pixel values already lie in a known, fixed range ([0, 255] for 8-bit images), so they can be rescaled with fixed constants instead of statistics estimated from the data.
There are two main reasons to normalize images. First, the large pixel range [0, 255] can lead to very large gradients (in the worst case, exploding gradients). Second, smaller input values usually make training converge faster. You can either
normalize each image so that its pixel values lie in [-1, 1], or
divide every value by the maximum pixel value, i.e. 255, so that the pixel values lie in [0, 1].
Normalization also matters when you use transfer learning. Suppose you reuse a pretrained model that was trained on images with pixel values in [0, 1]. You should then make sure that the new inputs you feed to the model lie in the same range. Otherwise the results will be distorted.
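The two scaling options above can be compared on a tiny example. This is a minimal NumPy sketch that assumes 8-bit images with the fixed range [0, 255]. The [-1, 1] variant here uses the fixed formula x / 127.5 - 1 rather than statistics computed per image.

```python
import numpy as np

# Tiny 1x3 "image": the darkest value, a dark gray and the brightest 8-bit value.
img = np.array([[0, 51, 255]], dtype=np.uint8)

# Divide by the maximum pixel value -> range [0, 1] (as in the notebook cell below)
to_unit = img / 255.0

# Fixed shift and scale -> range [-1, 1]
to_symmetric = img / 127.5 - 1.0

print(to_unit)
print(to_symmetric)
```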
[4]:
# Normalization: scale pixel values from [0, 255] to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
Expand dimensions
We add a new dimension to the existing dimensions of the image data. This new axis represents the number of channels in the data, so the shape changes from (60000, 28, 28) to (60000, 28, 28, 1).
Color images would have 3 channels, one each for red, green and blue. The MNIST images are grayscale, so there is only 1 channel.
[5]:
# Expand dimensions: add a channel axis, (N, 28, 28) -> (N, 28, 28, 1)
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
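You can check the effect of [..., tf.newaxis] on the array shape with NumPy alone, because tf.newaxis, like np.newaxis, is simply None. Below is a minimal sketch that uses a zero-filled stand-in array (an assumption) with the same per-image shape as MNIST.

```python
import numpy as np

# Stand-in for a batch of 5 grayscale 28x28 images (no channel axis yet).
batch = np.zeros((5, 28, 28), dtype=np.float32)

# "..." keeps all existing axes; np.newaxis (i.e. None) appends a new last axis.
with_channel = batch[..., np.newaxis]

# Equivalent alternative using np.expand_dims.
also_with_channel = np.expand_dims(batch, axis=-1)

print(batch.shape)         # (5, 28, 28)
print(with_channel.shape)  # (5, 28, 28, 1)
```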
5. Define the model
CNN layers:
Conv2D:
The most commonly used type of convolution is the 2D convolution layer, usually abbreviated as Conv2D. A filter (or kernel) in a Conv2D layer "slides" over the 2D input data. At each position it multiplies its weights element-wise with the input values it covers, and these products are summed into a single output pixel.
Parameters used when creating a Conv2D layer:
32: the number of filters in this convolution layer. Powers of 2 are commonly used for this value, although this is a convention rather than a requirement.
(3, 3): the dimensions of the kernel. Common sizes are 1×1, 3×3, 5×5 or 7×7, passed as the tuples (1, 1), (3, 3), (5, 5) or (7, 7). The value must be an integer or a tuple/list of 2 integers that specifies the height and width of the 2D convolution window. Odd sizes are the usual choice because the kernel then has a well-defined center pixel, but Keras does not require them.
activation="...": the name of the activation function applied after the convolution (see below).
MaxPooling2D (more details: https://www.geeksforgeeks.org/cnn-introduction-to-pooling-layer/?ref=header_outind)
CNNs use the pooling layer to reduce the spatial dimensions (width and height) of the input feature maps while retaining the most important information. A two-dimensional window is moved over each channel of a feature map, and the values in the area covered by the window are summarized into one value. Max pooling keeps the maximum.
Pooling layers shrink the spatial size of the feature maps. This reduces the computations in the following layers and the number of parameters in the subsequent dense layers, which makes the model faster and more efficient. Reducing the spatial dimensions can also help reduce overfitting.
Flatten:
A Flatten layer converts the multi-dimensional output of the previous layer into a one-dimensional array before it is passed to a fully connected (dense) layer for further processing.
It reduces the number of axes, not the number of values. For example, a 3×3×64 feature map becomes a vector of 576 values. This lets the model use ordinary dense layers and keeps the architecture simple.
Dense: a Dense layer is a fully connected layer, in which every output unit is connected to every input value.
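The four layer types described above can be sketched in plain NumPy for a single-channel input. This is a simplified illustration, not how Keras implements them. It assumes one input channel, one filter, stride 1 and no padding (Keras' default "valid" padding), and it uses random weights. The helper names conv2d_single, max_pool2d and dense are hypothetical. Like Keras, the convolution is computed as a cross-correlation, so the kernel is not flipped.

```python
import numpy as np


def conv2d_single(image, kernel):
    """Hypothetical helper: slide a kernel over a 2D image (stride 1, no padding)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Element-wise product of the window with the kernel, summed to one pixel
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def max_pool2d(feature_map, size=2):
    """Hypothetical helper: keep the maximum of each non-overlapping size x size window."""
    h = feature_map.shape[0] // size * size   # drop rows that do not fill a window
    w = feature_map.shape[1] // size * size   # drop columns that do not fill a window
    trimmed = feature_map[:h, :w]
    return trimmed.reshape(h // size, size, w // size, size).max(axis=(1, 3))


def dense(x, weights, bias):
    """Hypothetical helper: fully connected layer, every output uses every input."""
    return x @ weights + bias


rng = np.random.default_rng(0)
image = rng.random((28, 28))            # stand-in for one normalized MNIST image
kernel = rng.standard_normal((3, 3))    # one random 3x3 filter

fmap = conv2d_single(image, kernel)     # Conv2D: (26, 26)
pooled = max_pool2d(fmap)               # MaxPooling2D: (13, 13)
flat = pooled.reshape(-1)               # Flatten: (169,)
logits = dense(flat, rng.standard_normal((flat.size, 10)), np.zeros(10))  # Dense: (10,)

print(fmap.shape, pooled.shape, flat.shape, logits.shape)
```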
(Typical) activation functions
relu:
The ReLU activation function introduces nonlinearity into a neural network. It helps mitigate the vanishing gradient problem when training machine learning models and lets neural networks learn more complex relationships in the data. For a positive input, ReLU returns the input unchanged. For a negative input, it returns zero, i.e. ReLU(x) = max(0, x).
softmax:
The softmax function is often used in the last layer of a neural network for classification tasks. It converts raw output scores, also called logits, into probabilities. To do this, it takes the exponential of each output and divides each value by the sum of all exponentials: softmax(z)_i = exp(z_i) / Σ_j exp(z_j).
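Both activation functions are short enough to write out directly. The following is a minimal NumPy sketch, not the Keras implementation. Subtracting the maximum logit before exponentiating is a common trick to avoid overflow. It does not change the result, because the shift cancels out in the division.

```python
import numpy as np


def relu(x):
    """ReLU: keep positive values, replace negative values with 0."""
    return np.maximum(0, x)


def softmax(logits):
    """Turn raw scores (logits) into probabilities that sum to 1."""
    shifted = logits - np.max(logits)   # numerical stability; cancels out below
    exp = np.exp(shifted)
    return exp / np.sum(exp)


print(relu(np.array([-2.0, 0.0, 3.0])))   # [0. 0. 3.]
probs = softmax(np.array([2.0, 1.0, 0.1]))
print(probs, probs.sum())
```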
[6]:
# Define the model
model = models.Sequential(
    [
        layers.Input(shape=(28, 28, 1)),  # 28x28 grayscale images with 1 channel
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dense(10, activation="softmax"),  # one probability per digit 0-9
    ]
)
Good visualizations of a similar architecture can be found here:
https://miro.medium.com/v2/resize:fit:1400/format:webp/1*vkQ0hXDaQv57sALXAJquxA.jpeg
https://miro.medium.com/v2/resize:fit:1400/format:webp/1*uAeANQIOQPqWZnnuH-VEyw.jpeg
(Credits: Sumit Saha, Towards Data Science, "A Comprehensive Guide to Convolutional Neural Networks — the ELI5 way"; article on Medium, link omitted from the build because Medium often returns 403.)
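You can work out the output shapes and parameter counts of the model above by hand. This pure-Python sketch assumes the Keras defaults used in the cell: stride 1 and "valid" padding for Conv2D, and a pool stride equal to the pool size for MaxPooling2D. A Conv2D layer has (kernel height × kernel width × input channels + 1) × filters parameters. A Dense layer has (inputs + 1) × units parameters. These counts correspond to the kernel and bias arrays that model.get_weights() returns, e.g. the shapes (3, 3, 32, 64) and (576, 64) in the output further below. The helper names are hypothetical.

```python
def conv_out(size, kernel=3):
    """Hypothetical helper: spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Hypothetical helper: spatial size after non-overlapping pooling."""
    return size // pool


def conv_params(kernel, in_ch, filters):
    return (kernel * kernel * in_ch + 1) * filters   # weights + one bias per filter


def dense_params(inputs, units):
    return (inputs + 1) * units                       # weights + one bias per unit


size, channels, total = 28, 1, 0
for filters, pool in [(32, True), (64, True), (64, False)]:
    total += conv_params(3, channels, filters)
    size, channels = conv_out(size), filters
    if pool:
        size = pool_out(size)
    print(f"after Conv2D({filters}){' + MaxPooling2D' if pool else ''}: {size}x{size}x{channels}")

flat = size * size * channels            # Flatten
total += dense_params(flat, 64) + dense_params(64, 10)
print("flattened:", flat, "total parameters:", total)
```

This gives 576 flattened values, which matches the (576, 64) kernel of the first Dense layer, and 93,322 parameters in total. model.summary() should report the same figures.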
6. Compile the model
[7]:
# Compile the model
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",  # labels are integers 0-9, not one-hot
    metrics=["accuracy"],
)
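The loss "sparse_categorical_crossentropy" is meant for integer class labels such as the digits 0 to 9 in y_train. "Sparse" means the labels do not need to be one-hot vectors. For one sample with true label y and predicted probabilities p, the loss is -log(p[y]), and Keras averages it over the batch. Below is a minimal NumPy sketch of that formula with made-up probabilities. It is not the Keras implementation, and the function name is a hypothetical stand-in.

```python
import numpy as np


def sparse_categorical_crossentropy(labels, probs):
    """Hypothetical stand-in: mean of -log(probability of the true class)."""
    picked = probs[np.arange(len(labels)), labels]   # p[y] for every sample
    return float(np.mean(-np.log(picked)))


# Made-up softmax outputs for 2 samples and 3 classes (each row sums to 1).
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([0, 2])                 # integer labels, no one-hot encoding

print(sparse_categorical_crossentropy(labels, probs))
```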
7. Training
[8]:
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
Epoch 1/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 12s 6ms/step - accuracy: 0.9534 - loss: 0.1507 - val_accuracy: 0.9856 - val_loss: 0.0488
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 13s 7ms/step - accuracy: 0.9859 - loss: 0.0463 - val_accuracy: 0.9840 - val_loss: 0.0533
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 14s 7ms/step - accuracy: 0.9890 - loss: 0.0352 - val_accuracy: 0.9899 - val_loss: 0.0327
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 15s 8ms/step - accuracy: 0.9917 - loss: 0.0267 - val_accuracy: 0.9890 - val_loss: 0.0365
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 16s 8ms/step - accuracy: 0.9933 - loss: 0.0206 - val_accuracy: 0.9913 - val_loss: 0.0293
[8]:
<keras.src.callbacks.history.History at 0x306ac42f0>
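The "1875/1875" in the training log is the number of batches per epoch. model.fit was called without batch_size, so Keras used its default of 32 samples per batch. With 60,000 training images this gives:

```python
import math

train_samples = 60_000   # len(x_train) for MNIST
batch_size = 32          # Keras default when batch_size is not passed to fit()

steps_per_epoch = math.ceil(train_samples / batch_size)   # a last partial batch would count too
print(steps_per_epoch)   # 1875
```

With epochs=5 the model therefore performs 5 × 1875 = 9,375 weight updates.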
[9]:
model.get_weights()
[9]:
[array([[[[-2.20920473e-01, -7.94491619e-02, -6.55768160e-03,
-1.67344511e-01, 4.06573564e-02, 9.16692317e-02,
1.59453467e-01, 7.33715445e-02, 5.15761822e-02,
1.51928384e-02, 2.35505372e-01, 4.93809208e-02,
1.42657220e-01, 2.56727114e-02, 1.32331820e-02,
-2.01917097e-01, -4.89609018e-02, -1.42563239e-01,
-6.69733211e-02, 2.49209125e-02, -7.21121654e-02,
-1.09157324e-01, -8.19490664e-03, 1.74524680e-01,
-1.00255124e-01, -1.63895801e-01, -3.22987698e-02,
7.25202933e-02, -7.89696723e-02, -7.60411248e-02,
1.13907091e-01, 8.21158662e-02]],
[[ 5.29988930e-02, -3.60815167e-01, 8.38052407e-02,
8.62759426e-02, -2.92696562e-02, 1.95439041e-01,
-8.02631304e-02, 3.64616476e-02, -5.28480746e-02,
9.72734243e-02, 3.02234530e-01, 6.07382506e-02,
-4.71719094e-02, -2.30713248e-01, 1.07276186e-01,
-9.69349965e-02, 1.92033976e-01, 1.25463903e-01,
-5.52276522e-02, -2.13665634e-01, -2.00767651e-01,
1.63289726e-01, -2.85120636e-01, 1.77466944e-02,
-2.83153430e-02, 5.96216954e-02, 1.53643191e-01,
5.43419598e-03, -2.71694232e-02, 1.95470929e-01,
2.06592754e-01, 1.70671437e-02]],
[[-5.47746487e-04, -1.36313200e-01, 8.81105885e-02,
-2.55618989e-02, 7.11751878e-02, -6.61918297e-02,
-1.64077580e-01, -3.78607780e-01, 4.54623550e-02,
-2.05549300e-01, -1.88396767e-01, -1.82072908e-01,
-1.70547172e-01, -3.04500729e-01, 2.47008756e-01,
2.53357351e-01, -7.01055750e-02, 6.28316775e-02,
1.77779451e-01, -1.61328502e-02, -3.20997864e-01,
2.01325834e-01, -2.91512072e-01, -2.32901767e-01,
2.49616206e-02, 1.62868239e-02, -5.37151694e-02,
1.03772536e-01, 1.63174868e-01, -6.53485805e-02,
2.14372754e-01, -2.71023810e-01]]],
[[[ 1.42177984e-01, -2.17010319e-01, 9.53710154e-02,
-4.18881699e-02, 1.02515344e-03, -2.59148449e-01,
1.81316257e-01, -9.85033810e-02, 2.48357803e-02,
2.43994445e-02, 2.73050010e-01, -3.21224742e-02,
2.15805799e-01, -1.41091710e-02, -1.48583993e-01,
-4.66286957e-01, -2.22558990e-01, 1.09945469e-01,
-1.88796729e-01, 7.46641532e-02, 1.89557299e-01,
1.98169306e-01, 5.76510653e-02, 5.60208224e-02,
1.07446268e-01, -2.59754658e-01, 1.21596009e-01,
1.59706790e-02, 1.14251323e-01, 1.59909436e-03,
8.54656473e-02, -9.69819427e-02]],
[[ 1.88999146e-01, 7.98606724e-02, 1.16631072e-02,
-3.39207388e-02, -3.54416929e-02, 8.63766763e-03,
-2.72432286e-02, -6.11493349e-01, 1.58249393e-01,
2.14425549e-01, -1.93805862e-02, 9.79106128e-02,
9.72954091e-03, 1.18725277e-01, 1.04483915e-02,
1.14888199e-01, 2.30073556e-01, 4.49464247e-02,
-1.67520523e-01, -2.00613514e-02, 1.89994350e-01,
-9.05582309e-03, 1.67186335e-02, -1.18125580e-01,
1.98961824e-01, 2.28117947e-02, -5.56683308e-03,
8.84333551e-02, 2.03834504e-01, 1.20300345e-01,
9.64103267e-02, 1.91511348e-01]],
[[ 1.11986905e-01, 2.57325739e-01, -1.25424592e-02,
7.45124668e-02, 5.58233149e-02, -3.95383984e-02,
-2.58078814e-01, -6.42040223e-02, 5.15178069e-02,
1.18728228e-01, -4.18353975e-01, 1.07019544e-01,
-3.48158747e-01, -1.88075006e-01, 2.58432478e-01,
2.29934245e-01, -9.77288261e-02, 9.79964361e-02,
-1.46402325e-02, -8.91939923e-02, 1.14863291e-01,
-8.29333961e-02, -1.08123459e-01, -3.71893406e-01,
9.89398956e-02, 1.76147595e-01, 5.95058911e-02,
-6.12055920e-02, 5.77863082e-02, -8.17916244e-02,
3.68913338e-02, 1.32307917e-01]]],
[[[ 6.95364773e-02, 2.51678437e-01, -7.89162815e-02,
7.30204210e-02, -5.05873002e-02, -2.11244464e-01,
2.47483812e-02, -5.00733912e-01, -1.59897149e-01,
-1.66043520e-01, 3.22126932e-02, -1.77697733e-03,
1.44355744e-01, 1.34088621e-01, -3.73311490e-01,
-1.74173132e-01, -2.68429250e-01, -3.05692144e-02,
-3.95541161e-01, 1.52650297e-01, -7.71371648e-02,
-1.06454082e-01, 1.21975549e-01, 2.92527050e-01,
1.42879114e-01, 4.06415462e-02, 1.45763099e-01,
-2.76064485e-01, 1.98648381e-03, -3.79160307e-02,
-3.14531267e-01, -4.84688580e-02]],
[[-1.32999852e-01, 1.15969032e-01, 8.20346400e-02,
1.49580285e-01, 8.36936533e-02, -2.40755174e-02,
2.23825455e-01, 2.18706429e-02, -2.20397890e-01,
-1.25085041e-01, -4.10221338e-01, 1.22000620e-01,
1.70271412e-01, 1.73176065e-01, -3.11257094e-01,
-1.08548351e-01, 2.13634506e-01, 1.22949667e-01,
1.02562025e-01, 5.71810305e-02, 2.48645842e-02,
-1.72762349e-01, 2.07926556e-01, 2.27537408e-01,
-7.71648884e-02, 8.04939717e-02, 8.10799748e-02,
-5.40046021e-02, -1.46931589e-01, 1.57796115e-01,
-3.49180013e-01, -9.06167030e-02]],
[[-1.37200013e-01, 1.06969319e-01, -6.87700957e-02,
5.88214435e-02, -2.52333209e-02, 5.31174615e-02,
-2.06401169e-01, 3.80136132e-01, -1.18226081e-01,
5.33868819e-02, -9.06113386e-02, 7.10080052e-03,
-1.97398961e-01, 1.79805905e-01, -9.28899422e-02,
1.52331367e-01, -1.22799454e-02, 3.64275798e-02,
1.87574059e-01, -1.57145098e-01, 1.05322644e-01,
-1.85443774e-01, 2.74148762e-01, -4.11142455e-03,
-1.98894858e-01, 1.44984841e-01, -6.28811195e-02,
9.43158865e-02, -2.57780373e-01, 2.70449929e-02,
-1.70540676e-01, 1.29828125e-01]]]], dtype=float32),
array([-0.11964628, -0.00853754, -0.12067217, -0.14822854, -0.12620844,
-0.06389564, -0.0206668 , 0.10146309, -0.01160986, -0.0636593 ,
0.01471836, -0.14386651, -0.01522872, -0.01808114, 0.03797244,
-0.01213757, -0.10286602, -0.15156734, -0.03810443, -0.05511363,
-0.02217144, -0.01239665, -0.01430051, -0.02510639, -0.11962832,
-0.06226835, -0.11643692, -0.08747373, -0.07341461, -0.13219371,
-0.00681282, -0.0792897 ], dtype=float32),
array([[[[-2.87141770e-01, 2.72438042e-02, -5.11086211e-02, ...,
-2.40444839e-02, 1.01456687e-01, 8.34921598e-02],
[ 9.85200480e-02, -3.03475372e-02, -1.06990263e-01, ...,
1.24072209e-02, 1.46683957e-03, 6.44328520e-02],
[-3.54707181e-01, -8.89599137e-03, 2.07795147e-02, ...,
-4.01142761e-02, 9.51791555e-02, -4.14183103e-02],
...,
[-2.02539966e-01, -3.63311879e-02, -2.17840932e-02, ...,
-6.80495948e-02, 1.22527994e-01, 1.30075544e-01],
[-1.31371379e-01, 7.45408759e-02, 1.43486951e-02, ...,
-1.47490846e-02, -1.23880312e-01, -2.19554886e-01],
[-2.69234836e-01, 4.33345251e-02, 6.68406114e-02, ...,
-1.35258496e-01, 3.69792357e-02, -2.93387752e-02]],
[[-1.70327663e-01, -7.54242241e-02, 8.91283154e-03, ...,
4.74838354e-03, 5.11242747e-02, 7.47668967e-02],
[-7.14420602e-02, -1.40910763e-02, -6.69435337e-02, ...,
1.48552790e-01, -5.73910326e-02, -8.31986070e-02],
[-1.86255619e-01, -8.67929980e-02, -1.16934348e-02, ...,
-3.16584222e-02, -5.45332991e-02, -1.20528033e-02],
...,
[-1.63636878e-01, 1.20767683e-01, 1.70605853e-02, ...,
-1.79877669e-01, -7.97864422e-02, 5.68156652e-02],
[ 4.56921235e-02, -2.06424311e-01, 1.15143880e-01, ...,
6.30126074e-02, 2.74168272e-02, -1.09053887e-01],
[-1.16048634e-01, 8.29766542e-02, -2.95195766e-02, ...,
-7.05948025e-02, -3.92884351e-02, -3.66110057e-02]],
[[-2.18459725e-01, -1.14658348e-01, -2.64537372e-02, ...,
3.04673240e-02, -7.59111866e-02, 7.08029866e-02],
[-2.48813793e-01, 7.48770386e-02, 3.39205116e-02, ...,
4.24497165e-02, -1.24660566e-01, -1.75484538e-01],
[ 2.22609541e-03, -8.50358829e-02, -5.62276132e-03, ...,
9.36959088e-02, -2.45707426e-02, 5.13431355e-02],
...,
[-8.24836642e-02, 6.77229790e-03, -9.19606239e-02, ...,
3.59022468e-02, -5.44607081e-02, 7.43014291e-02],
[ 1.80594563e-01, -1.15248241e-01, 8.66105705e-02, ...,
-1.86142817e-01, 5.24160229e-02, -2.89593667e-01],
[-1.07689261e-01, -2.31931537e-01, -1.13854095e-01, ...,
-7.43491426e-02, 2.18279902e-02, -8.69931057e-02]]],
[[[ 1.15748502e-01, 5.72851598e-02, 2.51108687e-02, ...,
-8.96187946e-02, 1.04319990e-01, -3.71962748e-02],
[ 1.82982773e-01, -4.83286101e-03, -1.28907442e-01, ...,
-1.11754186e-01, -6.75523803e-02, -4.66443114e-02],
[ 3.09388675e-02, 6.55728728e-02, -5.33341244e-02, ...,
2.35964637e-02, 6.05051368e-02, 6.25400692e-02],
...,
[ 6.82009384e-02, -7.47776553e-02, 1.71388611e-02, ...,
-8.33416060e-02, 3.54118571e-02, 1.21152230e-01],
[ 2.89506763e-01, 1.62052214e-01, 1.35738179e-02, ...,
9.97384042e-02, -8.80349651e-02, -1.25831261e-01],
[ 6.69354275e-02, 3.15060541e-02, 9.49378535e-02, ...,
-1.20546669e-01, 2.51164548e-02, -2.73378957e-02]],
[[-6.62067309e-02, -1.27493262e-01, 1.90546200e-01, ...,
-7.00002015e-02, 1.87407099e-02, 7.63981277e-03],
[ 8.92735645e-02, -2.33126190e-02, 1.99162975e-01, ...,
5.10066077e-02, -1.35354355e-01, 1.17260136e-01],
[-2.89361179e-02, 4.83213887e-02, 8.49392712e-02, ...,
-7.56066144e-02, 1.99991986e-02, -7.93926269e-02],
...,
[-4.43059132e-02, 2.80666295e-02, -2.20480189e-01, ...,
-1.93504795e-01, -8.76336545e-02, 2.26779915e-02],
[-1.76911280e-01, 6.28711507e-02, 1.97020710e-01, ...,
-5.29049262e-02, -1.16330191e-01, -4.87897173e-02],
[-3.60079855e-02, 1.44704401e-01, 8.19573402e-02, ...,
8.52851272e-02, 4.86281924e-02, -1.21772606e-02]],
[[-1.99234173e-01, -1.16440035e-01, 1.86773926e-01, ...,
-2.26203632e-02, 3.25202607e-02, 7.40338638e-02],
[-9.39182490e-02, 1.32643972e-02, 7.39964992e-02, ...,
1.37882322e-01, 6.55246004e-02, 1.47754569e-02],
[-6.64413646e-02, -1.96749461e-03, 2.17679627e-02, ...,
1.05354842e-02, -3.58506516e-02, -2.94759739e-02],
...,
[-7.96286911e-02, -1.87324286e-02, -1.46458060e-01, ...,
1.10436425e-01, -2.83425394e-02, 8.48354921e-02],
[-2.12218508e-01, -1.49263233e-01, 1.61070332e-01, ...,
-1.30533174e-01, -7.25617409e-02, 1.50137311e-02],
[-1.06645055e-01, 7.01546073e-02, -5.47982287e-03, ...,
1.36719510e-01, -1.37953460e-01, 6.12698458e-02]]],
[[[ 7.90937394e-02, -3.04283248e-03, -2.13030726e-01, ...,
-1.97056204e-01, 1.70838699e-01, 1.06945910e-01],
[-1.39185111e-03, 1.04038410e-01, -1.64143920e-01, ...,
-5.09495437e-02, -4.16874848e-02, -6.91269934e-02],
[ 1.51556566e-01, -1.30391970e-01, -8.60519260e-02, ...,
-8.31990689e-02, 2.68980358e-02, -8.78927577e-03],
...,
[ 1.69954374e-01, -3.19773331e-02, -1.34970054e-01, ...,
-2.81187028e-01, 6.58432543e-02, -1.05596848e-01],
[ 7.10185468e-02, -1.74731594e-02, 3.74397226e-02, ...,
5.72115779e-02, -2.69202795e-02, 4.12565559e-01],
[ 9.22055170e-02, -6.78981692e-02, -1.34359986e-01, ...,
-1.15458816e-01, 1.87833626e-02, 9.39635113e-02]],
[[ 8.26287791e-02, 3.35943513e-02, -6.68488666e-02, ...,
-2.88449496e-01, 8.33515264e-03, 1.11621633e-01],
[-2.13350151e-02, -8.93757585e-03, -1.97948098e-01, ...,
1.58864949e-02, -6.34026304e-02, -5.35957329e-02],
[ 1.33022532e-01, -8.74736384e-02, -1.36090174e-01, ...,
-3.66353057e-02, -3.16930078e-02, 1.15373187e-01],
...,
[ 9.74067375e-02, -1.07977875e-01, -2.71959424e-01, ...,
4.87390943e-02, 2.47190297e-02, -1.40682802e-01],
[-2.07911119e-01, 1.28951399e-02, 8.14483594e-03, ...,
-2.23321274e-01, 4.29714397e-02, 3.81014705e-01],
[ 6.49554059e-02, 2.36556102e-02, -1.01164579e-01, ...,
2.49743666e-02, 2.68268809e-02, -7.28107840e-02]],
[[-1.30449221e-01, -1.08525708e-01, -3.88075076e-02, ...,
4.06191573e-02, -8.30720291e-02, 2.38718698e-04],
[-9.00636911e-02, -1.04837172e-01, -1.56476468e-01, ...,
5.87328039e-02, -2.01005116e-02, -1.55294865e-01],
[-1.15635164e-01, -2.36605536e-02, -2.63836324e-01, ...,
1.96182132e-02, 3.76341045e-02, 3.59329619e-02],
...,
[-4.14120518e-02, 3.38348635e-02, -1.60050854e-01, ...,
6.38134032e-02, -7.75649399e-02, -4.17994931e-02],
[-1.14109345e-01, 8.61790404e-02, -4.43402417e-02, ...,
-1.10245809e-01, -1.29732698e-01, 2.76057720e-01],
[-6.74327388e-02, 1.60675079e-01, -2.46481165e-01, ...,
1.09411940e-01, 4.32899371e-02, -7.89948832e-03]]]],
shape=(3, 3, 32, 64), dtype=float32),
array([ 0.06731296, -0.02459595, -0.02482229, -0.02666466, -0.05645316,
0.01924576, -0.08723029, -0.10761139, -0.048223 , -0.02819997,
-0.06707903, -0.05372068, -0.07045425, -0.10744666, -0.01075376,
-0.11981165, 0.00394488, -0.0332419 , -0.08568995, -0.05768878,
-0.12302567, -0.01584917, -0.0035959 , -0.05352783, -0.00870674,
-0.04014177, -0.01989994, -0.09150246, -0.01127943, 0.01514494,
-0.04301269, 0.00545209, -0.07562286, 0.00512845, -0.03049872,
-0.05419792, -0.0653836 , -0.06375174, 0.02577538, -0.01500314,
-0.01567873, 0.03286972, 0.01425118, -0.04360417, -0.04756039,
-0.05892737, -0.03354201, 0.02487741, -0.07415319, -0.09755213,
-0.05163696, 0.0085671 , -0.00129141, -0.07027808, -0.03040549,
-0.01221041, -0.00892876, -0.03424224, -0.07298737, -0.08493306,
-0.01933318, -0.02557231, -0.09205134, -0.07720712], dtype=float32),
array([[[[ 1.38326483e-02, -5.60028590e-02, -1.11946099e-01, ...,
-9.23848376e-02, -4.39743586e-02, 1.12006485e-01],
[ 1.07546322e-01, -4.26881202e-02, -1.57403445e-03, ...,
3.96482274e-02, 2.21814346e-02, 2.42187977e-01],
[-4.24904115e-02, 3.52870498e-04, -6.20885938e-02, ...,
1.42450957e-02, -2.14181077e-02, -1.56311579e-02],
...,
[-7.11930841e-02, -7.82431290e-02, -1.05166230e-02, ...,
-2.45241448e-02, -6.72213510e-02, -1.86088294e-01],
[-2.67039631e-02, -2.32834220e-02, 5.95758809e-03, ...,
-2.33040042e-02, -6.80534318e-02, 1.66961864e-01],
[ 6.92032427e-02, 3.10762487e-02, -7.18235523e-02, ...,
5.84568316e-03, 1.51486741e-02, -4.54714745e-02]],
[[ 8.95656124e-02, -3.28155085e-02, -1.65620461e-01, ...,
-9.39185619e-02, -8.11804458e-02, 1.16463840e-01],
[ 5.78018427e-02, 1.85013115e-02, -7.45760649e-02, ...,
-1.41571816e-02, -3.58420704e-03, -6.55916799e-03],
[-9.83992741e-02, -7.88790286e-02, -7.09981397e-02, ...,
-5.64843230e-02, -4.88687083e-02, -8.44919086e-02],
...,
[ 7.02290796e-03, -7.61678442e-02, -1.21449027e-02, ...,
1.31721184e-01, 2.19784323e-02, 2.32389960e-02],
[ 7.33735859e-02, 1.03031527e-02, -1.71359852e-02, ...,
3.50432955e-02, -9.20630395e-02, 2.83694472e-02],
[-2.53777597e-02, -1.05815500e-01, -1.22579671e-02, ...,
-1.20061614e-01, -1.14446633e-01, -2.14739367e-01]],
[[-1.54941082e-01, -7.03022033e-02, -4.40699607e-02, ...,
-6.72838762e-02, 5.89839853e-02, -2.56626815e-01],
[-6.24654368e-02, 4.18878794e-02, -9.24368799e-02, ...,
-1.66033939e-01, 1.26917372e-02, -1.43030174e-02],
[-5.73151670e-02, -4.93946970e-02, -2.94560031e-03, ...,
8.56453329e-02, 5.83696598e-03, 9.15837362e-02],
...,
[-1.01416223e-01, -1.69722792e-02, -2.14953441e-02, ...,
-4.74646278e-02, 1.60891586e-03, -7.84632862e-02],
[-8.33573714e-02, -1.05455138e-01, -6.92916960e-02, ...,
1.16767764e-01, -4.31388710e-03, -3.08036625e-01],
[-1.53677627e-01, -1.79537218e-02, -4.93312813e-02, ...,
-1.89838633e-02, -8.01014751e-02, -1.23914368e-01]]],
[[[-2.21265808e-01, -3.21912989e-02, -7.94314817e-02, ...,
-3.22166877e-03, -4.58807349e-02, 3.29253405e-01],
[ 6.13443255e-02, -4.25678976e-02, 6.38110489e-02, ...,
-1.41201005e-03, 5.61865531e-02, 9.56140757e-02],
[-1.20172866e-01, -8.49731490e-02, 1.03288785e-01, ...,
-7.98950121e-02, -4.49094810e-02, 1.39608949e-01],
...,
[-7.61528537e-02, 3.20506580e-02, -2.50871405e-02, ...,
-7.38966092e-02, -3.21425265e-04, 1.36282891e-01],
[-1.44565761e-01, -4.32569981e-02, 2.04333365e-02, ...,
-6.63700625e-02, -3.81267555e-02, -2.30036005e-02],
[ 2.67012902e-02, -5.37795648e-02, -4.23380360e-02, ...,
-5.63229509e-02, -5.12197688e-02, -3.12134009e-02]],
[[-7.43731558e-02, -3.62176038e-02, 5.10122851e-02, ...,
-3.51174027e-02, 3.87258045e-02, 2.27168292e-01],
[ 1.16739683e-01, -2.23928560e-02, -5.57525223e-03, ...,
-1.02788061e-01, -4.01335116e-03, 8.06760862e-02],
[-1.64217681e-01, 8.26110918e-05, -4.12965678e-02, ...,
7.52085214e-03, -6.78892136e-02, -1.75480649e-01],
...,
[ 8.83450210e-02, 5.34156477e-03, -7.77265951e-02, ...,
1.17669202e-01, -4.10092101e-02, -1.01568244e-01],
[ 8.69038403e-02, -1.88934822e-02, -1.17780408e-02, ...,
-7.16696978e-02, -9.30670425e-02, 1.23171233e-01],
[-2.29017869e-01, 6.00396749e-03, 1.90518517e-02, ...,
-1.41594961e-01, 2.73972610e-03, -2.70640224e-01]],
[[-2.36580167e-02, -9.35083553e-02, 1.94609556e-02, ...,
7.69268163e-03, -1.22647034e-02, 9.45108756e-02],
[ 9.35060829e-02, -2.52102762e-02, 3.71520706e-02, ...,
-9.34997201e-02, 8.46573431e-03, -5.98869994e-02],
[-3.28883454e-02, -2.59795431e-02, -9.67503898e-03, ...,
-3.68000045e-02, 1.75331309e-02, 9.37586650e-03],
...,
[-4.09095921e-03, -5.14208898e-02, -7.73036852e-02, ...,
-9.90556255e-02, 2.79615112e-02, -2.30321847e-02],
[ 1.59659505e-01, -1.82742421e-02, 5.14865806e-03, ...,
-7.09384605e-02, 2.56752246e-03, -4.10889760e-02],
[-6.88848943e-02, -7.74879605e-02, -3.57778445e-02, ...,
2.12062951e-02, -1.10647090e-01, -8.52660984e-02]]],
[[[-2.46268973e-01, 3.85357179e-02, -4.06927988e-02, ...,
3.74751054e-02, -9.25076082e-02, -1.29397631e-01],
[ 7.63643160e-02, -6.98825344e-02, 2.06106082e-02, ...,
9.54258293e-02, 1.68073755e-02, 1.37710047e-03],
[ 1.05941750e-01, -4.16890271e-02, 1.94569994e-02, ...,
1.56132326e-01, 1.72235724e-02, -4.77363281e-02],
...,
[-1.09409325e-01, 5.61159924e-02, -1.12776212e-01, ...,
-1.19272105e-01, -3.68169770e-02, 3.49496678e-02],
[-8.11009202e-03, 3.49640884e-02, -6.70032576e-02, ...,
-6.32542968e-02, 1.50984516e-02, -1.72102273e-01],
[ 8.10830444e-02, 3.48706394e-02, -3.07865832e-02, ...,
1.29607707e-01, -5.81163615e-02, 1.03756957e-01]],
[[-8.04790780e-02, -8.27980638e-02, -1.05671294e-01, ...,
-7.03366250e-02, 4.25918251e-02, 1.75646335e-01],
[-6.02779612e-02, 3.28739360e-02, -3.63444462e-02, ...,
-1.12059116e-01, -6.25926554e-02, 1.38955384e-01],
[-2.38637440e-02, -7.79598951e-02, 5.62047698e-02, ...,
2.28353627e-02, -2.37480737e-03, 5.32163940e-02],
...,
[-3.15365382e-02, -3.44829215e-03, -6.26939237e-02, ...,
-1.04815885e-01, 1.66660943e-03, -1.46268513e-02],
[ 7.27598695e-03, 9.55763459e-03, -4.28966023e-02, ...,
-4.77525927e-02, 1.72682526e-03, 7.72892591e-03],
[ 1.72994703e-01, -6.68177381e-02, 5.22779580e-03, ...,
4.69456948e-02, -1.33015667e-04, -2.45814715e-02]],
[[-9.57137123e-02, 7.37990811e-03, -5.69204502e-02, ...,
2.32768655e-01, -3.73396128e-02, 5.58465421e-02],
[-1.84253957e-02, -2.04128996e-02, 1.02314362e-02, ...,
1.96217880e-01, -8.14723894e-02, -1.19025171e-01],
[ 6.46807626e-02, 5.45739336e-03, 3.43166254e-02, ...,
1.29366472e-01, -2.66727451e-02, 6.07023761e-03],
...,
[ 5.86102270e-02, 6.51542982e-03, -6.93500265e-02, ...,
-1.63897485e-01, -6.41653836e-02, -3.82056534e-02],
[ 3.81928682e-02, -6.24898523e-02, -1.06647953e-01, ...,
-4.72425707e-02, 3.92622454e-03, 2.31024548e-02],
[ 1.48619562e-01, -6.53460994e-02, -3.18123102e-02, ...,
1.43899456e-01, 2.96681821e-02, 8.48675147e-02]]]],
shape=(3, 3, 64, 64), dtype=float32),
array([ 2.4550732e-03, -2.5643468e-02, -4.2420231e-02, 5.1648982e-02,
1.5094544e-02, -6.1443243e-03, 3.4930601e-03, -5.4257452e-02,
1.7708238e-02, 2.5660655e-02, 2.4773186e-02, -1.1195319e-02,
-3.9667740e-02, -3.3047036e-06, -4.8360862e-03, -2.7014380e-02,
-1.5237678e-02, -1.6341314e-02, -3.7341344e-03, -2.7565246e-02,
5.5930126e-03, 6.0194261e-02, -4.5430576e-03, 5.3550620e-02,
1.5513954e-02, -1.8517613e-02, -7.9647377e-03, -2.3070274e-02,
2.3695156e-02, -3.1956777e-02, -1.0008575e-02, 8.9354776e-03,
-1.7382259e-02, -2.1166353e-02, 2.7923085e-02, 2.8792849e-02,
-3.2289915e-02, 2.5300624e-02, 1.9008113e-02, 2.5793573e-02,
-2.9155096e-02, 2.2237476e-02, 3.1690761e-02, -3.3192068e-02,
-1.3683923e-02, -3.7145395e-02, -9.2115039e-03, 7.1950540e-02,
-3.7705472e-03, -2.8528892e-02, -6.3294269e-02, -1.5230717e-02,
-2.2907145e-02, -4.2952418e-02, -3.5951860e-02, 9.5042344e-03,
2.7293764e-02, 4.1889604e-02, 4.0394571e-02, 1.2372457e-02,
-3.5087023e-02, -3.0876681e-02, -2.3205196e-02, -2.1498492e-02],
dtype=float32),
array([[ 0.05708771, -0.06234324, -0.13463886, ..., -0.03414698,
0.06119309, -0.02520272],
[-0.02455455, 0.07509242, 0.07041636, ..., 0.09118779,
0.08221027, -0.08331598],
[-0.08261199, -0.04301149, -0.00914894, ..., -0.09411783,
-0.08325154, -0.00983006],
...,
[ 0.21591546, -0.0254946 , 0.07165627, ..., 0.19633831,
-0.1346423 , 0.08037651],
[ 0.04433294, 0.0122591 , -0.02097172, ..., 0.09993152,
-0.07984101, 0.00207634],
[-0.12268885, -0.03293889, 0.09699838, ..., 0.03583752,
-0.06936302, -0.05591767]], shape=(576, 64), dtype=float32),
array([ 0.01418659, -0.03349337, -0.06969396, -0.00489685, 0.02487434,
-0.01691333, -0.01663163, -0.06567444, -0.01818012, -0.01410949,
0.06172884, 0.04483001, -0.00748309, -0.01995968, -0.00552262,
0.08356399, 0.08863807, -0.00358188, -0.00643371, -0.07654649,
-0.01133776, -0.03647338, -0.04736697, -0.0670499 , -0.00294621,
0.0354447 , 0.04969401, -0.03378887, 0.09706189, 0.00067311,
-0.05088674, -0.06142539, -0.02328563, -0.03018828, -0.04188383,
0.00369699, -0.02562218, -0.04497053, 0.00346254, -0.03978704,
-0.054894 , -0.04106856, -0.02364217, -0.03416612, -0.04195937,
-0.03278247, 0.03881833, -0.04748299, 0.0227492 , 0.07585373,
0.07963715, 0.0418564 , 0.00727706, -0.08144415, 0.00310017,
0.01189614, 0.05605335, 0.10206967, 0.01744832, -0.02168363,
-0.00761675, -0.01551267, -0.0139414 , 0.03232931], dtype=float32),
array([[-5.16570956e-02, 1.82396144e-01, 1.86560884e-01,
7.24766180e-02, 5.22688739e-02, 8.71893018e-02,
-1.00037508e-01, 1.42719477e-01, -2.42041737e-01,
-3.66162956e-01],
[ 1.78905219e-01, -1.64840147e-01, -1.78555652e-01,
1.50711602e-02, 1.54459298e-01, 2.16973662e-01,
1.54395908e-01, -3.16675343e-02, 2.50614882e-01,
-7.16972718e-05],
[-1.62268102e-01, -3.70794564e-01, 1.01160772e-01,
-1.99901089e-01, -1.34590287e-02, -2.16006264e-01,
-2.27270067e-01, -1.05627939e-01, -2.30283424e-01,
9.00198817e-02],
[ 2.54945874e-01, -2.30583414e-01, -3.24434698e-01,
-4.64541800e-02, -1.14861749e-01, -3.08331758e-01,
2.65793532e-01, -4.29244488e-01, 2.64962077e-01,
1.47684097e-01],
[ 1.89971998e-01, -1.26820460e-01, 6.93956316e-02,
-3.10436875e-01, 2.72932410e-01, 1.20527938e-01,
4.83166575e-02, 2.14841276e-01, 2.21598551e-01,
1.94459423e-01],
[-1.01403847e-01, -2.39520654e-01, -1.85177177e-01,
-1.54134959e-01, -3.11129242e-01, 2.15920851e-01,
-1.25315174e-01, 1.47254735e-01, 1.28699071e-03,
-2.78044730e-01],
[ 5.26516140e-02, -2.37085894e-01, 1.80669680e-01,
-4.07854229e-01, 1.72871143e-01, -2.57931389e-02,
1.11917555e-02, -3.11586857e-01, -2.29372889e-01,
-1.52407646e-01],
[ 9.74876359e-02, -1.75573707e-01, 1.52443699e-03,
1.82614475e-01, -2.97166348e-01, 4.68500704e-02,
3.72588541e-03, -6.81607006e-03, -2.63306499e-01,
1.41870350e-01],
[ 2.06628874e-01, -2.04366744e-01, -3.61576468e-01,
-2.70182043e-01, 1.07519627e-01, -2.00734496e-01,
-3.27197790e-01, -1.40033245e-01, -3.32248926e-01,
-1.36627182e-01],
[-7.75504624e-04, 1.02751479e-01, -1.46293089e-01,
-2.88504511e-01, 1.03341743e-01, 2.85819829e-01,
-5.94628938e-02, 1.61925808e-01, -1.34175748e-01,
2.79708207e-01],
[-3.29341471e-01, 5.26261181e-02, -2.42110178e-01,
2.05533952e-02, 8.31093490e-02, -3.27978402e-01,
-3.25569212e-01, 1.07423492e-01, -3.75838876e-01,
-9.64015797e-02],
[ 1.03486061e-01, 2.48279169e-01, -1.97293177e-01,
-9.47673097e-02, 1.32593885e-01, -3.43340188e-02,
-3.10092688e-01, 3.17143261e-01, -1.05908655e-01,
1.29661053e-01],
[ 1.75281182e-01, 1.93607450e-01, 5.06153442e-02,
1.52850389e-01, 3.47571522e-01, -6.66841120e-02,
1.84733897e-01, -2.60123700e-01, -2.34570295e-01,
1.70061544e-01],
[-7.41054788e-02, -1.56462654e-01, 1.31585538e-01,
2.76431471e-01, -1.19654231e-01, -3.38970162e-02,
1.70878172e-02, 2.54455715e-01, -2.68775731e-01,
1.01187686e-02],
[ 1.57435015e-01, 6.48213401e-02, -1.64385252e-02,
5.30404598e-02, -2.11611688e-01, -1.32774353e-01,
-2.27555603e-01, -3.14288706e-01, -1.99659958e-01,
8.57193395e-03],
[-4.45573151e-01, -1.75526902e-01, -2.18004555e-01,
3.72956097e-02, 1.33197144e-01, -1.57117695e-02,
-3.10838133e-01, 8.82661194e-02, -3.29450928e-02,
4.38649580e-02],
[-7.97981471e-02, -1.14686064e-01, 2.68548913e-02,
1.54006585e-01, -2.44160309e-01, -2.73801506e-01,
-1.88529760e-01, -2.64297366e-01, 2.27251992e-01,
-1.08027138e-01],
[-7.94948190e-02, -3.89029942e-02, 2.16617256e-01,
-1.28654197e-01, -5.58328303e-03, -1.61028028e-01,
-3.77498806e-01, -1.45756423e-01, 1.44786075e-01,
1.27807304e-01],
[-1.85091898e-01, 1.32755846e-01, -1.11977933e-02,
-2.23677963e-01, -1.85212448e-01, -2.93326676e-01,
-5.74105941e-02, 1.98006853e-01, -1.65121943e-01,
2.15950832e-01],
[-1.87861118e-02, -9.10503343e-02, -3.07070632e-02,
-1.54559031e-01, -1.01825327e-01, -1.42532066e-01,
-1.91016406e-01, -1.35894924e-01, -2.96260208e-01,
2.77623355e-01],
[ 1.66226804e-01, 1.63931400e-01, -2.25374743e-01,
-1.02028362e-02, 2.01440603e-01, 1.10866509e-01,
5.08405641e-02, -1.63740918e-01, 2.22709626e-01,
2.42358387e-01],
[ 1.41369656e-01, -1.40989572e-01, 2.89542705e-01,
-1.51578054e-01, -9.12421048e-02, -1.67799622e-01,
-4.33283627e-01, -9.54575464e-03, -6.95173163e-03,
2.52436459e-01],
[ 2.85074830e-01, 5.97847905e-03, -2.32044280e-01,
1.00973941e-01, -3.39538038e-01, -2.47360080e-01,
-1.88782454e-01, 5.32565312e-03, -8.85380581e-02,
7.35052899e-02],
[ 1.78755045e-01, 1.47358239e-01, -1.23605870e-01,
8.52901489e-02, 6.41780673e-03, -1.41452134e-01,
-2.91708142e-01, 2.11124029e-02, 4.82854396e-02,
2.42745891e-01],
[ 1.85737789e-01, 1.87957972e-01, 2.53847670e-02,
-2.70438284e-01, 2.61514783e-01, 2.35514328e-01,
2.39865452e-01, 2.03062356e-01, 1.16571411e-01,
-1.47464484e-01],
[ 1.38842374e-01, -2.01592192e-01, 3.09788495e-01,
2.54305005e-01, -6.27078786e-02, -3.04349035e-01,
-2.23497748e-01, 5.78192845e-02, 1.97714180e-01,
4.28988263e-02],
[ 1.78229138e-01, -3.21614921e-01, -4.55862820e-01,
2.52224684e-01, 3.56726684e-02, 3.06170166e-01,
2.47444674e-01, -4.52093869e-01, 1.44500747e-01,
1.26036644e-01],
[ 1.94521084e-01, -2.03747347e-01, 2.21435785e-01,
6.18621632e-02, 1.60888791e-01, 2.47784197e-01,
2.60732651e-01, 2.75516063e-01, -4.30957926e-03,
-1.68438897e-01],
[-1.17162541e-01, 2.38961384e-01, -2.22619280e-01,
1.67741671e-01, -3.38397384e-01, 1.80762962e-01,
-8.22924450e-02, -1.92300335e-01, 1.82826534e-01,
-3.36759090e-01],
[-3.54503095e-01, 9.51806828e-03, -1.26831204e-01,
1.11947395e-01, 1.86633497e-01, -1.78340208e-02,
-2.74095953e-01, -2.73597576e-02, -7.75732845e-02,
2.32511282e-01],
[ 2.31229320e-01, -2.29861483e-01, -2.50747025e-01,
1.54900640e-01, -2.51247019e-01, -9.20872092e-02,
-4.38264571e-02, -6.01947010e-02, -3.70690674e-02,
5.72180841e-03],
[ 1.29668742e-01, 1.67938948e-01, 2.71530509e-01,
9.87362117e-02, -1.11610256e-02, -9.49455947e-02,
-7.03838691e-02, -1.90909311e-01, -9.61497501e-02,
-2.22085848e-01],
[-4.54211887e-03, -2.98274577e-01, -1.97932228e-01,
-1.71578214e-01, 1.96355850e-01, 1.17920609e-02,
-1.43543676e-01, 3.69117647e-01, -6.83321804e-02,
6.39851242e-02],
[ 2.66058117e-01, -2.21391246e-01, -6.51206821e-02,
2.48825580e-01, 2.70042233e-02, 7.69974142e-02,
7.67013654e-02, -1.87259927e-01, -1.68538824e-01,
2.20064402e-01],
[ 2.01051891e-01, 1.33354038e-01, 1.32612005e-01,
1.67933568e-01, -2.69793123e-01, 1.72866434e-01,
1.52714968e-01, 6.39446527e-02, -4.15093489e-02,
-2.70082891e-01],
[-1.14600575e-02, 1.41058400e-01, 2.51689047e-01,
4.42901673e-03, -1.33357257e-01, 2.12791517e-01,
-1.77473668e-02, 2.59607017e-01, 1.19011775e-01,
-3.02698433e-01],
[-1.32166475e-01, 1.66911885e-01, 2.87316162e-02,
-8.29543024e-02, -1.10832443e-02, -4.63532470e-02,
1.49000511e-01, -1.87354848e-01, -3.10452342e-01,
1.52855620e-01],
[ 8.17507803e-02, -1.13676637e-01, -6.13319129e-02,
-2.76785463e-01, -1.56265944e-01, -2.50247747e-01,
1.38710305e-01, 2.38066941e-01, -1.74162373e-01,
-2.39063397e-01],
[ 1.96594775e-01, -2.74036169e-01, -8.38667229e-02,
2.29640026e-02, 2.19124749e-01, -2.34084919e-01,
2.11351439e-01, -3.24498564e-01, 6.66174963e-02,
-4.03303616e-02],
[ 2.75130093e-01, -1.52934849e-01, -1.46691397e-01,
-2.88175851e-01, -4.06669348e-01, 5.65760508e-02,
4.17029597e-02, 1.14666998e-01, -2.04257533e-01,
-1.91317797e-01],
[-1.70486182e-01, 8.48507062e-02, -1.51479170e-01,
1.15432711e-02, 1.90412775e-01, 1.97004318e-01,
-1.39617264e-01, -2.48450980e-01, -2.18839403e-02,
4.33079228e-02],
[ 1.66243352e-02, -3.51673663e-02, -1.18699059e-01,
1.64179951e-01, 2.07475826e-01, 6.36996981e-03,
-2.97149181e-01, 6.26099994e-03, -1.93936884e-01,
1.19045772e-01],
[-2.34536842e-01, -2.92002052e-01, -8.14765990e-02,
-1.23654418e-01, 8.89466256e-02, -1.06820785e-01,
5.86572997e-02, 1.38415217e-01, 6.30669817e-02,
2.67673820e-01],
[ 2.24965349e-01, -1.34028852e-01, -3.06553453e-01,
-1.28137171e-01, -3.60821225e-02, -2.40799621e-01,
3.14842671e-01, -1.33109808e-01, -3.24275881e-01,
2.14633778e-01],
[ 1.45587787e-01, -4.68877405e-02, -2.01387659e-01,
8.34583715e-02, 1.81788400e-01, -2.07731719e-04,
1.24553092e-01, -1.99418649e-01, 1.07602052e-01,
8.28800648e-02],
[-5.74691221e-02, 1.35082810e-03, -2.02241033e-01,
2.48051181e-01, -9.00799632e-02, 2.64533609e-01,
-3.13308567e-01, 1.03611529e-01, -2.79019535e-01,
4.96034771e-02],
[ 7.48019144e-02, -3.07263732e-01, -2.96288401e-01,
1.08347572e-01, 2.02868655e-01, 9.72235128e-02,
1.35799825e-01, -3.53739820e-02, 2.12206632e-01,
3.01186055e-01],
[ 1.82164505e-01, -2.44176805e-01, -1.58329472e-01,
-2.26171255e-01, -1.32030800e-01, 2.37682581e-01,
-1.37904966e-02, 1.37954220e-01, 3.48205939e-02,
-2.58509349e-02],
[-3.58668923e-01, -7.95790367e-03, -2.96002775e-01,
1.31813928e-01, -3.23758125e-01, 3.19456100e-01,
-2.04832956e-01, -3.81792724e-01, -2.01389521e-01,
7.19458684e-02],
[-1.24115035e-01, 2.55270958e-01, -9.58794057e-02,
-2.26069152e-01, 2.96746284e-01, -6.19703978e-02,
1.93892360e-01, -3.35892051e-01, 2.40316659e-01,
-2.16129005e-01],
[ 1.41175613e-01, 1.98802382e-01, 6.41119182e-02,
-2.44437769e-01, 5.93550839e-02, -2.23857194e-01,
-4.21661466e-01, 1.99799687e-01, 1.51105165e-01,
-1.09498642e-01],
[-9.89007577e-02, 1.57155842e-01, 1.59081534e-01,
3.08500469e-01, -1.02180459e-01, 2.31799990e-01,
2.73230374e-01, 2.31182531e-01, 2.42998704e-01,
-3.09382677e-01],
[ 1.45925164e-01, -8.92660543e-02, 1.45201921e-01,
2.41680801e-01, 7.16382116e-02, -1.47748515e-01,
1.98164552e-01, 3.24054360e-01, -2.22253263e-01,
-2.05984369e-01],
[-7.03814849e-02, 1.83393851e-01, -1.79929554e-01,
-2.37406537e-01, -7.27819130e-02, -8.79087448e-02,
1.96527746e-02, -2.14605957e-01, 1.10800490e-01,
-2.10513532e-01],
[-8.59806463e-02, 1.76193550e-01, 1.73069030e-01,
-9.20611024e-02, -2.36101240e-01, 1.21646918e-01,
2.78342962e-01, -8.48453045e-02, 1.64842486e-01,
1.68510452e-01],
[-1.52394146e-01, -2.50590563e-01, 2.70139366e-01,
3.09226215e-01, -6.77545518e-02, 1.16110161e-01,
1.21644542e-01, 2.77483672e-01, -1.75407201e-01,
-3.19735438e-01],
[-1.29462734e-01, -2.01208040e-01, 1.04181226e-02,
-1.35632172e-01, 2.17418656e-01, 9.36727524e-02,
-1.58911109e-01, -1.90711603e-01, 2.38178521e-01,
5.52048944e-02],
[-8.05972591e-02, 1.48052126e-01, -3.75475466e-01,
3.79764438e-02, 1.24343233e-02, -6.51395991e-02,
8.06608424e-02, 2.07118884e-01, 3.84490967e-01,
7.20644146e-02],
[ 1.92045867e-01, -3.11668962e-02, 3.13478976e-01,
9.90396645e-03, -2.58267522e-01, -2.90753283e-02,
1.91320434e-01, -1.75890923e-01, 3.04402053e-01,
-1.63461879e-01],
[ 4.34739590e-02, -1.88699216e-01, 1.56138837e-01,
2.07077473e-01, -2.81859279e-01, 1.08558983e-01,
-2.40348920e-01, -1.11143710e-02, -7.54054040e-02,
-1.86252594e-02],
[-5.11724763e-02, 3.47667307e-01, -2.60294467e-01,
-9.30366665e-02, 1.42463103e-01, -2.65860975e-01,
-9.15907025e-02, 2.44312793e-01, -4.10301000e-01,
1.86489657e-01],
[-1.29526660e-01, 1.33382335e-01, 1.53816730e-01,
-3.20706904e-01, -7.38243386e-02, -1.43649042e-01,
2.27065638e-01, -1.99850172e-01, 1.17242811e-02,
-3.99768539e-02],
[ 2.21772581e-01, -1.56109646e-01, -1.24776006e-01,
1.75048411e-01, -3.22463602e-01, 8.00449178e-02,
-1.92207679e-01, 4.06748056e-02, -2.07004949e-01,
1.96689054e-01],
[-1.71793252e-01, -1.18684553e-01, 2.27854028e-01,
2.94921517e-01, -2.24115521e-01, -3.83285373e-01,
-1.00005284e-01, -1.30508587e-01, -3.95162925e-02,
-3.12202543e-01]], dtype=float32),
array([-0.08606234, 0.04817877, -0.04294467, 0.01732274, 0.02866258,
-0.0510773 , -0.08001343, 0.02867183, 0.09399392, 0.00924843],
dtype=float32)]
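The last two arrays in the output above are probably the weights of the final output layer. The first is a kernel with 10 columns, one per digit class. The second is a bias with 10 values.

The sketch below shows how such a layer turns features into class probabilities. It is a minimal NumPy sketch and rests on three assumptions that this excerpt does not confirm:

* the arrays belong to a `Dense(10)` layer;
* that layer uses a softmax activation;
* the feature size `k = 64` is a hypothetical value, and the random arrays are stand-ins for the printed ones.

With the real model, `kernel, bias = model.layers[-1].get_weights()` should return the actual arrays, provided the last layer is that Dense layer.

```python
import numpy as np

def dense_softmax(h, kernel, bias):
    """Forward pass of a Dense layer with softmax activation (assumed).

    h:      (batch, k) features from the preceding layer
    kernel: (k, 10) weight matrix, one column per digit class
    bias:   (10,) bias vector
    """
    logits = h @ kernel + bias                            # (batch, 10)
    logits = logits - logits.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)           # each row now sums to 1

# Hypothetical stand-ins for the arrays printed above
rng = np.random.default_rng(0)
k = 64  # hypothetical feature size; the real value is the kernel's first dimension
kernel = rng.normal(scale=0.2, size=(k, 10)).astype(np.float32)
bias = rng.normal(scale=0.05, size=(10,)).astype(np.float32)
h = rng.random((3, k)).astype(np.float32)  # three hypothetical feature vectors

probs = dense_softmax(h, kernel, bias)
print(probs.shape)           # (3, 10)
print(probs.sum(axis=1))     # each row sums to ~1
print(probs.argmax(axis=1))  # predicted class per sample
```

Subtracting the row maximum before `np.exp` does not change the softmax result. It only prevents overflow for large logits.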
8. Evaluate the model
[10]:
# Evaluate the trained model on the 10,000 test images
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_acc}")
print(f"Test loss: {test_loss}")
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.9913 - loss: 0.0293
Test accuracy: 0.9912999868392944
Test loss: 0.02932381071150303
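The accuracy reported above is the share of test images whose most probable class matches the true label. The sketch below computes it by hand from predicted probabilities and accepts either integer or one-hot labels, since this excerpt does not show which format `y_test` uses.

With the real model, `accuracy_from_probs(model.predict(x_test), y_test)` should come out close to `test_acc`, assuming the model was compiled with an accuracy metric. The example data is made up.

The sketch also explains the "313/313" in the progress bar. Keras processes data in batches of 32 by default, and the MNIST test set has 10,000 images, so evaluation takes ⌈10000 / 32⌉ = 313 steps.

```python
import math
import numpy as np

def accuracy_from_probs(probs, labels):
    """Share of samples whose most probable class equals the true label.

    probs:  (n, num_classes) predicted class probabilities
    labels: (n,) integer labels or (n, num_classes) one-hot labels
    """
    labels = np.asarray(labels)
    if labels.ndim == 2:               # one-hot -> integer labels
        labels = labels.argmax(axis=1)
    return float(np.mean(probs.argmax(axis=1) == labels))

# Small hand-made example with 3 classes: 3 of 4 predictions are correct
probs = np.array([
    [0.9, 0.1, 0.0],   # predicted 0, true 0
    [0.2, 0.7, 0.1],   # predicted 1, true 1
    [0.3, 0.3, 0.4],   # predicted 2, true 2
    [0.6, 0.3, 0.1],   # predicted 0, true 1 (wrong)
])
labels = np.array([0, 1, 2, 1])
print(accuracy_from_probs(probs, labels))  # 0.75

# Number of batches: 10,000 test images at the default batch size of 32
print(math.ceil(10_000 / 32))  # 313
```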
9. Example prediction
[11]:
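# Predict class probabilities for all test images
predictions = model.predict(x_test)
# Show the first test image together with its most probable class
plt.imshow(x_test[0].reshape(28, 28), cmap="gray")
plt.title(f"Predicted class: {predictions[0].argmax()}")
plt.show()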
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step
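The cell above runs all 10,000 test images through the network just to display one of them. A Keras model expects a leading batch axis, so a single image can be passed as `x_test[:1]` or `np.expand_dims(x_test[0], axis=0)`.

With the real model, `model.predict(x_test[:1])` should then return a `(1, 10)` array, assuming `x_test` has the shape `(n, 28, 28, 1)` after preprocessing.

The sketch below checks those shapes on a hypothetical stand-in array. It then lists the most probable classes from one hypothetical output row. Reading the row as probabilities assumes a softmax output layer.

```python
import numpy as np

# Hypothetical stand-in for the preprocessed x_test: (n, 28, 28, 1), values in [0, 1]
x_test_demo = np.random.default_rng(1).random((5, 28, 28, 1)).astype(np.float32)

# Slicing keeps the batch axis; plain indexing (x_test_demo[0]) would drop it
single_by_slice = x_test_demo[:1]                          # shape (1, 28, 28, 1)
single_by_expand = np.expand_dims(x_test_demo[0], axis=0)  # same shape, batch axis added back
print(single_by_slice.shape, single_by_expand.shape)

def top_k_classes(prob_row, k=3):
    """Return the k most probable (class, probability) pairs, best first."""
    idx = np.argsort(prob_row)[::-1][:k]
    return [(int(i), float(prob_row[i])) for i in idx]

# Hypothetical output row for one image: 10 class probabilities summing to 1
prob_row = np.array([0.01, 0.00, 0.02, 0.01, 0.00, 0.01, 0.00, 0.90, 0.01, 0.04])
print(top_k_classes(prob_row))  # [(7, 0.9), (9, 0.04), (2, 0.02)]
```

Looking at the runner-up classes as well as the top one can show which digits the model confuses, for example a 7 that also scores somewhat as a 9.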