Fix: import matplotlib.pyplot as plt / show_train_history(train_history)

Original code:
import matplotlib.pyplot as plt

show_train_history(train_history)

Resulting error:
KeyError                                  Traceback (most recent call last)

Cell In[51], line 2
      1 import matplotlib.pyplot as plt  
----> 2 show_train_history(train_history)

Cell In[50], line 5, in show_train_history(train_history)
      3 fig.set_size_inches(16, 6)
      4 plt.subplot(121)
----> 5 plt.plot(train_history.history["acc"])
      6 plt.plot(train_history.history["val_acc"])
      7 plt.title("Train History")

KeyError: 'acc'
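A KeyError here means the key string is not among the names recorded in train_history.history, which is an ordinary Python dict mapping each metric name to a per-epoch list. Printing its keys shows what is actually available. The sketch below uses a hand-made dict in place of a real History object (an assumption, so it runs without TensorFlow), and find_metric_key is a hypothetical helper, not part of Keras.

```python
def find_metric_key(history_dict, candidates=("accuracy", "acc")):
    """Return the first name in `candidates` that exists in a History.history dict.

    Hypothetical helper: Keras 2.3+ (and tf.keras) records the metric under the
    name passed to compile(); older Keras reported metrics=["accuracy"] as "acc".
    """
    for key in candidates:
        if key in history_dict:
            return key
    raise KeyError(f"none of {candidates} recorded; available: {sorted(history_dict)}")


# Hand-made stand-in for train_history.history (made-up values, not real results)
recorded = {"loss": [0.9, 0.5], "accuracy": [0.6, 0.8],
            "val_loss": [1.0, 0.7], "val_accuracy": [0.55, 0.75]}

print(sorted(recorded))           # ['accuracy', 'loss', 'val_accuracy', 'val_loss']
print(find_metric_key(recorded))  # accuracy
```

With a real run, print(sorted(train_history.history)) gives the same kind of listing and shows immediately whether "acc" or "accuracy" is the right key.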

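The cause: newer Keras versions (2.3 and later, including tf.keras in TensorFlow 2.x) record a metric under the exact name passed to compile(), so with metrics=["accuracy"] the keys are "accuracy" and "val_accuracy", not "acc" and "val_acc". Changing the keys fixes it: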
import matplotlib.pyplot as plt

def show_train_history(history):
    fig, ax = plt.subplots(figsize=(16, 6))

    # history.history is a dict mapping each metric name to a list of per-epoch values
    ax.plot(history.history["accuracy"], label="Training Accuracy")
    ax.plot(history.history["val_accuracy"], label="Validation Accuracy")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_title("Train History")
    ax.legend()
    plt.show()

# Assuming a model `model` compiled with metrics=["accuracy"] and data x_train, y_train, x_test, y_test
history = model.fit(x_train, y_train, batch_size=64, epochs=10, validation_data=(x_test, y_test), verbose=1)

# Call the function to display the plot
show_train_history(history)
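The traceback shows the original function called plt.subplot(121), so it apparently drew two side-by-side panels, while the fix above draws accuracy only. Below is a sketch that restores a two-panel layout and accepts either key name. That the second panel showed loss is an assumption (the traceback only shows the first few lines), and FakeHistory is a hypothetical stand-in for the object model.fit returns, so the example runs without TensorFlow.

```python
import matplotlib.pyplot as plt


def show_train_history(history, acc_key=None):
    """Plot training/validation accuracy and loss side by side.

    Only history.history (a dict of per-epoch lists) is used. If acc_key is
    None, "accuracy" is tried first, then the older "acc" name.
    """
    h = history.history
    if acc_key is None:
        acc_key = "accuracy" if "accuracy" in h else "acc"

    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(16, 6))

    ax_acc.plot(h[acc_key], label="train")
    if "val_" + acc_key in h:  # present only when validation data was given
        ax_acc.plot(h["val_" + acc_key], label="validation")
    ax_acc.set_title("Train History: " + acc_key)
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel(acc_key)
    ax_acc.legend()

    ax_loss.plot(h["loss"], label="train")
    if "val_loss" in h:
        ax_loss.plot(h["val_loss"], label="validation")
    ax_loss.set_title("Train History: loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("loss")
    ax_loss.legend()

    return fig


class FakeHistory:
    """Hypothetical stand-in exposing a .history dict like Keras's History."""

    def __init__(self, history):
        self.history = history


fake = FakeHistory({"loss": [0.9, 0.5, 0.4], "accuracy": [0.6, 0.8, 0.85],
                    "val_loss": [1.0, 0.7, 0.6], "val_accuracy": [0.55, 0.75, 0.8]})
fig = show_train_history(fake)
```

In a notebook, pass the real History object instead (show_train_history(history)) and call plt.show() afterwards. Returning the figure rather than showing it inside the function keeps it usable for fig.savefig(...).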
