訂閱
糾錯(cuò)
加入自媒體

如何用Tensorflow框架構(gòu)建用于食品分類的機(jī)器學(xué)習(xí)模型?

這是數(shù)據(jù)框的視圖,

下一步就是制作一個(gè)對(duì)象,將圖片放入模型中。我們將練習(xí)tf.keras.preprocessing.image庫(kù)的ImageDataGenerator對(duì)象。使用此對(duì)象,我們將生成圖像批次。此外,我們可以擴(kuò)充我們的圖片以擴(kuò)大數(shù)據(jù)集的乘積。因?yàn)槲覀冞擴(kuò)展了這些圖片,我們進(jìn)一步設(shè)置了圖像增強(qiáng)技術(shù)的參數(shù)。

此外,因?yàn)槲覀儜?yīng)用了一個(gè)數(shù)據(jù)幀作為關(guān)于數(shù)據(jù)集的知識(shí),因此我們將使用flow_from_dataframe方法生成批處理并增強(qiáng)圖片。上面的代碼看起來(lái)像這樣from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Create the ImageDataGenerator object
train_datagen = ImageDataGenerator(
   featurewise_center=True,
   featurewise_std_normalization=True,
   rotation_range=20,
   width_shift_range=0.2,
   height_shift_range=0.2,
   horizontal_flip=True,

val_datagen = ImageDataGenerator(
   featurewise_center=True,
   featurewise_std_normalization=True,
   rotation_range=20,
   width_shift_range=0.2,
   height_shift_range=0.2,
   horizontal_flip=True,

# Generate batches and augment the images
train_generator = train_datagen.flow_from_dataframe(
   df_train,
   directory='Food-5K/training/',
   x_col='filename',
   y_col='label',
   class_mode='binary',
   target_size=(224, 224),

val_generator = train_datagen.flow_from_dataframe(
   df_val,
   directory='Food-5K/validation/',
   x_col='filename',
   y_col='label',
   class_mode='binary',
   target_size=(224, 224),

步驟3:訓(xùn)練模型

決定好批次之后,我們可以通過(guò)遷移學(xué)習(xí)技術(shù)來(lái)訓(xùn)練模型。因?yàn)槲覀儜?yīng)用了這種方法,所以我們不需要從頭開(kāi)始執(zhí)行 CNN 架構(gòu)。相反,我們將使用當(dāng)前和以前預(yù)訓(xùn)練的架構(gòu)。我們將應(yīng)用 ResNet-50 作為我們新模型的脊椎。我們將生成輸入并根據(jù)類別數(shù)量調(diào)整 ResNet-50 的最后一個(gè)線性層ResNet-50。

構(gòu)建模型的代碼如下

from tensorflow.keras.a(chǎn)pplications import ResNet50
# Initialize the Pretrained Model
feature_extractor = ResNet50(weights='imagenet',
                            input_shape=(224, 224, 3),
                            include_top=False)
# Set this parameter to make sure it's not being trained
feature_extractor.trainable = False
# Set the input layer
input_ = tf.keras.Input(shape=(224, 224, 3))
# Set the feature extractor layer
x = feature_extractor(input_, training=False)
# Set the pooling layer
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# Set the final layer with sigmoid activation function
output_ = tf.keras.layers.Dense(1, activation='sigmoid')(x)
# Create the new model object
model = tf.keras.Model(input_, output_)
# Compile it
model.compile(optimizer='adam',
            loss='binary_crossentropy',
            metrics=['accuracy'])
# Print The Summary of The Model
model.summary()

為了訓(xùn)練模型,我們采用擬合的方法來(lái)準(zhǔn)備模型。這是代碼,model.fit(train_generator, epochs=20, validation_data=val_generator)

步驟4:測(cè)試模型

在訓(xùn)練模型之后,現(xiàn)在讓我們?cè)跍y(cè)試數(shù)據(jù)集上檢查模型。在擴(kuò)展中,我們需要結(jié)合一個(gè)pillow庫(kù)來(lái)加載和調(diào)整圖片大小,以及 scikit-learn 來(lái)確定模型性能。我們將練習(xí)來(lái)自 scikit-learn 庫(kù)的分類報(bào)告,以生成關(guān)于模型執(zhí)行的報(bào)告。此外,我們會(huì)喜歡它的混淆矩陣。這是預(yù)測(cè)實(shí)驗(yàn)數(shù)據(jù)及其決策的代碼,from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
y_true = []
y_pred = []
for i in os.listdir('Food-5K/evaluation'):
   img = Image.open('Food-5K/evaluation/' + i)
   img = img.resize((224, 224))
   img = np.a(chǎn)rray(img)
   img = np.expand_dims(img, 0)
   
   y_true.a(chǎn)ppend(int(i.split('_')[0]))
   y_pred.a(chǎn)ppend(1 if model.predict(img) > 0.5 else 0)
   
print(classification_report(y_true, y_pred))
print()
print(confusion_matrix(y_true, y_pred))

從前面的內(nèi)容可以看出,該模型的性能已超過(guò)95%。因此,我們可以在建立圖像分類器 API 的情況下接受此模型。

步驟5:保存模型

如果你希望將模型用于后續(xù)使用或部署,你可以使用 save 方法保存模型,model.save('./resnet50_food_model')

如果你需要加載模型,你可以像這樣練習(xí)load_model方法,model = tf.keras.models.load_model('./resnet50_food_model')

下一步是什么

做得好!現(xiàn)在你已了解如何使用 TensorFlow 執(zhí)行遷移學(xué)習(xí)。我希望這項(xiàng)研究能鼓勵(lì)你,尤其是那些渴望在數(shù)據(jù)不足的情況下訓(xùn)練深度學(xué)習(xí)模型的人。

<上一頁(yè)  1  2  
聲明: 本文由入駐維科號(hào)的作者撰寫(xiě),觀點(diǎn)僅代表作者本人,不代表OFweek立場(chǎng)。如有侵權(quán)或其他問(wèn)題,請(qǐng)聯(lián)系舉報(bào)。

發(fā)表評(píng)論

0條評(píng)論,0人參與

請(qǐng)輸入評(píng)論內(nèi)容...

請(qǐng)輸入評(píng)論/評(píng)論長(zhǎng)度6~500個(gè)字

您提交的評(píng)論過(guò)于頻繁,請(qǐng)輸入驗(yàn)證碼繼續(xù)

  • 看不清,點(diǎn)擊換一張  刷新

暫無(wú)評(píng)論

暫無(wú)評(píng)論

    掃碼關(guān)注公眾號(hào)
    OFweek人工智能網(wǎng)
    獲取更多精彩內(nèi)容
    文章糾錯(cuò)
    x
    *文字標(biāo)題:
    *糾錯(cuò)內(nèi)容:
    聯(lián)系郵箱:
    *驗(yàn) 證 碼:

    粵公網(wǎng)安備 44030502002758號(hào)