webentwicklung-frage-antwort-db.com.de

Verwendung der k-fachen Kreuzvalidierung in einem neuronalen Netzwerk

Wir schreiben eine kleine ANN, die 7000 Produkte anhand von 10 Eingabevariablen in 7 Klassen einteilen soll.

Um dies zu tun, müssen wir die k-fache Kreuzvalidierung verwenden, aber wir sind irgendwie verwirrt.

Wir haben diesen Auszug aus der Präsentationsfolie:

k-fold cross validation diagram

Was genau sind die Validierungs- und Testsätze?

Nach unserem Verständnis durchlaufen wir die 3 Trainingssätze und stellen die Gewichte ein (einzelne Epoche). Was machen wir dann mit der Validierung? Denn soweit ich weiß, wird der Testsatz verwendet, um den Fehler des Netzwerks zu ermitteln.

Was als nächstes passiert, verwirrt mich auch. Wann findet die Frequenzweiche statt?

Wenn es nicht zu viel verlangt, wäre eine Aufzählungsliste mit Schritten wünschenswert

29
Ortixx

Sie scheinen ein bisschen verwirrt zu sein (ich erinnere mich, dass ich es auch war), also werde ich die Dinge für Sie vereinfachen. ;)

Beispielszenario für ein neuronales Netzwerk

Wenn Sie eine Aufgabe erhalten, beispielsweise ein neuronales Netzwerk zu entwickeln, erhalten Sie häufig auch einen Beispieldatensatz, der für Schulungszwecke verwendet werden kann. Nehmen wir an, Sie trainieren ein einfaches neuronales Netzwerksystem Y = W · X, Wobei Y die Ausgabe ist, die aus der Berechnung des Skalarprodukts (·) des Gewichtsvektors W mit einer gegebenen Stichprobe berechnet wird Vektor X. Nun, der naive Weg, dies zu tun, würde darin bestehen, den gesamten Datensatz von beispielsweise 1000 Proben zu verwenden, um das neuronale Netzwerk zu trainieren. Unter der Annahme, dass das Training konvergiert und sich Ihre Gewichte stabilisieren, können Sie sicher sagen, dass Ihr Netzwerk die Trainingsdaten korrekt klassifiziert. Aber was passiert mit dem Netzwerk, wenn zuvor nicht sichtbare Daten angezeigt werden? Der Zweck solcher Systeme besteht eindeutig darin, andere Daten als die für das Training verwendeten verallgemeinern und richtig klassifizieren zu können.

Überanpassung erklärt

In jeder realen Situation sind jedoch zuvor nicht angezeigte/neue Daten nur verfügbar, wenn Ihr neuronales Netzwerk in einer Produktionsumgebung (sozusagen) bereitgestellt wird. Aber da Sie es nicht ausreichend getestet haben, werden Sie wahrscheinlich eine schlechte Zeit haben. :) Das Phänomen, durch das jedes Lernsystem fast perfekt zu seinem Trainingssatz passt, aber ständig mit unsichtbaren Daten versagt, wird Überanpassung genannt.

Die drei Sätze

Hier kommen in der Validierung und Prüfung Teile des Algorithmus. Kehren wir zum ursprünglichen Datensatz von 1000 Proben zurück. Sie teilen es in drei Gruppen auf - Training , Validierung und Testen (Tr, Va und Te) - mit sorgfältig ausgewählten Proportionen. (80-10-10)% ist normalerweise ein guter Anteil, wobei:

  • Tr = 80%
  • Va = 10%
  • Te = 10%

Schulung und Validierung

Was nun passiert ist, dass das neuronale Netzwerk auf den Tr Satz trainiert wird und seine Gewichte korrekt aktualisiert werden. Der Validierungssatz Va wird dann verwendet, um den Klassifizierungsfehler E = M - Y Unter Verwendung der aus dem Training resultierenden Gewichte zu berechnen, wobei M der erwartete Ausgabevektor ist, der aus dem Validierungssatz entnommen wurde, und Y ist die berechnete Ausgabe, die sich aus der Klassifizierung ergibt (Y = W * X). Wenn der Fehler einen benutzerdefinierten Schwellenwert überschreitet, wird die gesamte Trainingsvalidierungs-Epoche wiederholt. Diese Trainingsphase endet, wenn der mit dem Validierungssatz berechnete Fehler als niedrig genug eingestuft wird.

Intelligentes Training

Hier ist es eine clevere List, zufällig auszuwählen, welche Samples für das Training und die Validierung aus der Gesamtmenge Tr + Va Bei jeder Epocheniteration verwendet werden sollen. Dies stellt sicher, dass das Netzwerk nicht über das Trainingsset passt.

Testen

Die Testmenge Te wird dann verwendet, um die Leistung des Netzwerks zu messen. Diese Daten sind perfekt für diesen Zweck, da sie während der Schulungs- und Validierungsphase nie verwendet wurden. Es handelt sich im Grunde genommen um eine kleine Menge bisher nicht sichtbarer Daten, die nachahmen sollen, was passieren würde, wenn das Netzwerk in der Produktionsumgebung bereitgestellt wird.

Die Leistung wird erneut anhand des Klassifizierungsfehlers wie oben erläutert gemessen. Die Leistung kann (oder sollte) auch in Form von Präzision und Rückruf gemessen werden, um zu wissen, wo und wie der Fehler auftritt, aber das ist das Thema für eine andere Frage und Antwort.

Quervalidierung

Wenn man diesen Mechanismus zum Testen der Trainingsvalidierung verstanden hat, kann man das Netzwerk gegen Überanpassung weiter stärken, indem man K-fache Kreuzvalidierung durchführt. Dies ist in gewisser Weise eine Weiterentwicklung des oben erläuterten intelligenten Tricks. Diese Technik beinhaltet das Durchführen von K Runden von Trainingsvalidierungstests an verschiedenen, nicht überlappenden, gleichmäßig proportionierten Tr, Va und Te setzt .

Mit k = 10 Teilen Sie Ihren Datensatz für jeden Wert von K in Tr+Va = 90% Und Te = 10% Auf und führen den Algorithmus aus, wobei Sie die Testleistung aufzeichnen.

k = 10
for i in 1:k
     # Select unique training and testing datasets
     KFoldTraining <-- subset(Data)
     KFoldTesting <-- subset(Data)

     # Train and record performance
     KFoldPerformance[i] <-- SmartTrain(KFoldTraining, KFoldTesting)

# Compute overall performance
TotalPerformance <-- ComputePerformance(KFoldPerformance)

Überanpassung gezeigt

Ich übernehme die weltberühmte Handlung aus wikipedia , um zu zeigen, wie der Validierungssatz hilft, eine Überanpassung zu verhindern. Der Trainingsfehler in Blau nimmt tendenziell mit zunehmender Anzahl von Epochen ab. Das Netzwerk versucht daher, genau auf den Trainingssatz abzustimmen. Der Validierungsfehler in rot folgt dagegen einem anderen, u-förmigen Profil. Das Minimum der Kurve ist, wenn das Training idealerweise gestoppt werden sollte, da dies der Punkt ist, an dem der Trainings- und Validierungsfehler am geringsten ist.

Overfitting reduced by validating neural network

Verweise

Weitere Hinweise dieses ausgezeichnete Buch vermittelt Ihnen sowohl fundierte Kenntnisse des maschinellen Lernens als auch einige Migräne. Es liegt an Ihnen, zu entscheiden, ob es sich lohnt. :)

81
JoErNanO
  1. Teilen Sie Ihre Daten in K nicht überlappende Falten. Lassen Sie jede Falte K eine gleiche Anzahl von Elementen aus jeder der m Klassen enthalten (geschichtete Kreuzvalidierung; wenn Sie 100 Elemente aus Klasse A und 50 aus Klasse B haben und eine zweifache Validierung durchführen, sollte jede Falte zufällig 50 Elemente enthalten von A und 25 von B).

    1. Für i in 1..k:

      • Bezeichne fold i die Testfalte
      • Bestimmen Sie eine der verbleibenden k-1-Faltungen der Validierungsfalte (dies kann entweder zufällig sein oder eine Funktion von i, spielt keine Rolle)
      • Bezeichnen Sie alle verbleibenden Falten als Trainingsfalte
      • Führen Sie eine Rastersuche nach allen freien Parametern (z. B. Lernrate, Anzahl der Neuronen in der verborgenen Schicht) durch, indem Sie an Ihren Trainingsdaten trainieren und an Ihren Validierungsdaten Datenverluste berechnen. Wählen Sie Parameter, um den Verlust zu minimieren
      • Verwenden Sie den Klassifikator mit den Gewinnparametern, um den Testverlust zu bewerten. Sammeln Sie Ergebnisse

Sie haben jetzt über alle Falten hinweg Gesamtergebnisse gesammelt. Dies ist Ihre endgültige Leistung. Wenn Sie dies in der Praxis anwenden möchten, verwenden Sie die besten Parameter aus der Rastersuche, um alle Daten zu trainieren.

5
Ben Allison