webentwicklung-frage-antwort-db.com.de

Wie erstellen Sie eine benutzerdefinierte Aktivierungsfunktion mit Keras?

Manchmal reichen die standardmäßigen Standardaktivierungen wie ReLU, tanh, softmax, ... und die fortgeschrittenen Aktivierungen wie LeakyReLU nicht aus. Und es könnte auch nicht in keras-contrib sein.

Wie erstellen Sie Ihre eigene Aktivierungsfunktion?

13
Martin Thoma

Dank an diese Github-Ausgabe von Ritchie Ng .

from keras.layers import Activation
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

def custom_activation(x):
    return (K.sigmoid(x) * 5) - 1

get_custom_objects().update({'custom_activation': Activation(custom_activation)})

model.add(Activation(custom_activation))

Bitte beachten Sie, dass Sie diese Funktion beim Speichern und Wiederherstellen des Modells importieren müssen. Siehe die Notiz von Keras-Contrib .

30
Martin Thoma

Etwas einfacher als Martin Thomas Antwort : Sie können einfach eine benutzerdefinierte elementweise Back-End-Funktion erstellen und diese als Parameter verwenden. Sie müssen diese Funktion noch importieren, bevor Sie Ihr Modell laden können.

from keras import backend as K

def custom_activation(x):
    return (K.sigmoid(x) * 5) - 1

model.add(Dense(32 , activation=custom_activation))
10
Eponymous

Angenommen, Sie möchten swish oder gelu zu Keras hinzufügen. Die vorherigen Methoden sind Nice-Inline-Einfügungen. Sie können sie jedoch auch in die Kerasaktivierungsfunktionen einfügen, so dass Sie Ihre benutzerdefinierte Funktion aufrufen, wie Sie ReLU aufrufen würden. Ich habe das mit Keras 2.2.2 getestet (jede Version 2). Hängen Sie an diese Datei $HOME/anaconda2/lib/python2.7/site-packages/keras/activations.py die Definition Ihrer benutzerdefinierten Funktion an (kann für Ihre Python- und Anaconda-Version unterschiedlich sein).

In keras intern:

$HOME/anaconda2/lib/python2.7/site-packages/keras/activations.py

def swish(x):
    return (K.sigmoid(beta * x) * alpha *x)

Dann in deiner Python-Datei:

$HOME/Documents/neural_nets.py

model = Sequential()
model.add(Activation('swish'))
0
Julien Nyambal