
How to implement early stopping in TensorFlow

def train():
    # Model
    model = Model()

    # Loss, Optimizer
    global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step')
    loss_fn = model.loss()
    optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)

    # Summaries
    summary_op = summaries(model, loss_fn)

    with tf.Session(config=TrainConfig.session_conf) as sess:

        # Initialize variables, load saved state
        sess.run(tf.global_variables_initializer())
        model.load_state(sess, TrainConfig.CKPT_PATH)

        writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)

        # Input source
        data = Data(TrainConfig.DATA_PATH)

        loss = Diff()
        for step in range(global_step.eval(), TrainConfig.FINAL_STEP):

            mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step)

            mixed_spec = to_spectrogram(mixed_wav)
            mixed_mag = get_magnitude(mixed_spec)

            src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
            src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)

            src1_batch, _ = model.spec_to_batch(src1_mag)
            src2_batch, _ = model.spec_to_batch(src2_mag)
            mixed_batch, _ = model.spec_to_batch(mixed_mag)

            # Initialize our callback.
            # early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5)

            l, _, summary = sess.run([loss_fn, optimizer, summary_op],
                                     feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch,
                                                model.y_src2: src2_batch})

            loss.update(l)
            print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value))

            writer.add_summary(summary, global_step=step)

            # Save state
            if step % TrainConfig.CKPT_STEP == 0:
                tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step)

        writer.close()

I have this neural network code that separates the music from a voice in a WAV file. How can I introduce an early stopping algorithm to halt the training? I have seen a project that talks about a ValidationMonitor. Can someone help me?
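
(A minimal sketch of the idea, for illustration: a patience counter around the training loop above. The helper below is not part of the original code; its name, the patience value, and the choice to monitor the smoothed training loss are assumptions, and a validation loss would normally be preferred.)

    class EarlyStopper(object):
        """Signals when a monitored loss has stopped improving (hypothetical helper)."""

        def __init__(self, patience):
            self.patience = patience        # allowed steps without improvement
            self.best_loss = float('inf')
            self.bad_steps = 0

        def should_stop(self, loss_value):
            if loss_value < self.best_loss:
                self.best_loss = loss_value
                self.bad_steps = 0
            else:
                self.bad_steps += 1
            return self.bad_steps >= self.patience

With `stopper = EarlyStopper(patience=1000)` created before the loop, a check such as `if stopper.should_stop(loss.value): break` placed right after `loss.update(l)` would end training once the loss has stalled.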

10
Valerio

Here is my implementation of early stopping, which you can adapt:

Early stopping can be applied at certain stages of the training process, for example at the end of each epoch. Specifically, in my case, I monitor the test (validation) loss at each epoch; once the loss has not improved for 20 epochs (self.require_improvement = 20), the training is interrupted.

You can set the maximum number of epochs to 10000, 20000, or whatever you want (self.max_epochs = 10000).

    self.require_improvement = 20
    self.max_epochs = 10000

Here is my training function, where I use early stopping:

def train(self):

    # training data
    train_input = self.Normalize(self.x_train)
    train_output = self.y_train.copy()

    # cost history:
    costs = []
    costs_inter = []

    # for early stopping:
    best_cost = 1000000
    stop = False
    last_improvement = 0
    saver = tf.train.Saver()  # used to snapshot the weights of the best epoch

    n_samples = train_input.shape[0]  # size of the training set

    # train the model on mini-batches, using the early stopping criterion
    Epoch = 0
    while Epoch < self.max_epochs and not stop:
        # shuffle, then split the training set into mini-batches of size self.batch_size
        seq = list(range(n_samples))
        random.shuffle(seq)
        mini_batches = [
            seq[k:k + self.batch_size]
            for k in range(0, n_samples, self.batch_size)
        ]

        avg_cost = 0.  # the average cost over the mini-batches
        step = 0

        for sample in mini_batches:

            batch_x = train_input.iloc[sample, :]
            batch_y = train_output.iloc[sample, :]
            batch_y = np.array(batch_y).flatten()

            feed_dict = {self.X: batch_x, self.Y: batch_y, self.is_train: True}

            _, cost, acc = self.sess.run([self.train_step, self.loss_, self.accuracy_], feed_dict=feed_dict)
            avg_cost += cost * len(sample) / n_samples
            print('Epoch[{}] step [{}] train -- loss : {}, accuracy : {}'.format(Epoch, step, avg_cost, acc))
            step += 1

        # cost history since the last best cost
        costs_inter.append(avg_cost)

        # early stopping: stop once the cost has not improved for
        # self.require_improvement consecutive epochs
        if avg_cost < best_cost:
            saver.save(self.sess, './best_model.ckpt')  # snapshot the best weights so far (path is arbitrary)
            best_cost = avg_cost
            costs += costs_inter  # commit the cost history since the last improvement
            last_improvement = 0
            costs_inter = []
        else:
            last_improvement += 1
        if last_improvement > self.require_improvement:
            print("No improvement found during the last {} epochs, stopping optimization.".format(self.require_improvement))
            # Break out of the loop and roll back to the best weights.
            stop = True
            saver.restore(self.sess, './best_model.ckpt')

        # run validation after every epoch:
        print('---------------------------------------------------------')
        self.y_validation = np.array(self.y_validation).flatten()
        loss_valid, acc_valid = self.sess.run(
            [self.loss_, self.accuracy_],
            feed_dict={self.X: self.x_validation, self.Y: self.y_validation, self.is_train: False})  # evaluate in inference mode
        print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".format(Epoch + 1, loss_valid, acc_valid))
        print('---------------------------------------------------------')

        Epoch += 1

To recap, here is the important part of the code:

def train(self):
    ...
    # cost history:
    costs = []
    costs_inter = []
    # for early stopping:
    best_cost = 1000000
    stop = False
    last_improvement = 0
    saver = tf.train.Saver()  # used to snapshot the weights of the best epoch
    # train the model on mini-batches, using the early stopping criterion
    Epoch = 0
    while Epoch < self.max_epochs and not stop:
        ...
        for sample in mini_batches:
            ...
        # cost history since the last best cost
        costs_inter.append(avg_cost)

        # early stopping: stop once the cost has not improved for
        # self.require_improvement consecutive epochs
        if avg_cost < best_cost:
            saver.save(self.sess, './best_model.ckpt')  # snapshot the best weights so far (path is arbitrary)
            best_cost = avg_cost
            costs += costs_inter  # commit the cost history since the last improvement
            last_improvement = 0
            costs_inter = []
        else:
            last_improvement += 1
        if last_improvement > self.require_improvement:
            print("No improvement found during the last {} epochs, stopping optimization.".format(self.require_improvement))
            # Break out of the loop and roll back to the best weights.
            stop = True
            saver.restore(self.sess, './best_model.ckpt')
        ...
        Epoch += 1

Hope it helps someone :).

7
DINA TAKLIT

ValidationMonitor is marked as deprecated, so its use is not recommended, but you can still use it. Here is an example of how to create one:

    import functools
    from tensorflow.contrib.learn import monitors  # deprecated module, see above

    validation_monitor = monitors.ValidationMonitor(
        input_fn=functools.partial(input_fn, subset="evaluation"),
        eval_steps=128,
        every_n_steps=88,
        early_stopping_metric="accuracy",
        early_stopping_rounds=1000
    )
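
A usage sketch (assuming a tf.contrib.learn Estimator; `estimator` and `input_fn` here are placeholders): the monitor is attached when calling fit():

    # hypothetical tf.contrib.learn Estimator; the monitor is attached at fit() time
    estimator.fit(
        input_fn=functools.partial(input_fn, subset="training"),
        steps=10000,
        monitors=[validation_monitor])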

And you can implement it yourself; here is my implementation:

    # inside a tf.train.SessionRunHook; this assumes before_run() requested
    # [global_step, loss] via tf.train.SessionRunArgs
    def after_run(self, run_context, run_values):
        global_step, loss_value = run_values.results
        if loss_value < self.best_loss:
            self.stopping_step = 0
            self.best_loss = loss_value
        else:
            self.stopping_step += 1
        if self.stopping_step >= FLAGS.early_stopping_step:
            self.should_stop = True
            print("Early stopping is triggered at step: {} loss: {}".format(global_step, loss_value))
            run_context.request_stop()
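
For context, a complete, self-contained hook along these lines might look like the following sketch (the class name, the loss tensor name 'loss:0', and the patience value are assumptions):

    import tensorflow as tf

    class EarlyStoppingHook(tf.train.SessionRunHook):
        """Stops training once the loss has not improved for `patience` steps (sketch)."""

        def __init__(self, loss_name='loss:0', patience=1000):
            self.loss_name = loss_name      # assumed name of the loss tensor
            self.patience = patience
            self.best_loss = float('inf')
            self.stopping_step = 0

        def before_run(self, run_context):
            # fetch the loss value alongside every training step
            loss = run_context.session.graph.get_tensor_by_name(self.loss_name)
            return tf.train.SessionRunArgs(loss)

        def after_run(self, run_context, run_values):
            loss_value = run_values.results
            if loss_value < self.best_loss:
                self.best_loss = loss_value
                self.stopping_step = 0
            else:
                self.stopping_step += 1
            if self.stopping_step >= self.patience:
                print("Early stopping triggered, best loss: {}".format(self.best_loss))
                run_context.request_stop()

Such a hook can then be passed to tf.train.MonitoredTrainingSession(hooks=[...]) or to an Estimator's train() call.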
7
scott huang

Since TensorFlow version r1.10, early stopping hooks for the Estimator API are available in early_stopping.py (see GitHub).

For example, tf.contrib.estimator.stop_if_no_decrease_hook (see the docs).
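
A usage sketch (`model_fn` and `train_input_fn` are placeholders, and the metric name and step counts are arbitrary):

    import tensorflow as tf

    estimator = tf.estimator.Estimator(model_fn=model_fn)  # model_fn is a placeholder

    # stop if the "loss" eval metric has not decreased for 1000 training steps
    early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
        estimator,
        metric_name='loss',
        max_steps_without_decrease=1000,
        min_steps=100)

    estimator.train(input_fn=train_input_fn, hooks=[early_stopping])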

0
mdh