12. Benchmarking Action Unit detectors with evaluation data#

written by Tiankang Xie

In this tutorial we will demonstrate how to evaluate Py-Feat's AU detection models on evaluation data (DISFA Plus).

import sys
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from feat.utils.image_operations import extract_face_from_landmarks
import os 
import pandas as pd
from tqdm import tqdm
from feat import Detector
import glob
from skimage.feature import hog
import pickle
from feat.data import imageLoader_DISFAPlus
from sklearn.metrics import f1_score

Provide the paths for:

  1. the raw DISFA Plus dataset

  2. the directory where the validation results will be saved

You can request access to the dataset at http://mohammadmahoor.com/disfa/

save_result_dir = '/Storage/Projects/pyfeat_testing/Data_Eshin/au_test/'
data_dir = "/Storage/Data/DISFAPlusDataset/"
disfa_file_data = imageLoader_DISFAPlus(data_dir=data_dir)  # py-feat provides a dedicated dataloader for DISFA Plus
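
Before running the detectors, it can help to peek at what a single item from the dataloader contains. The snippet below is a minimal sanity check; it only assumes that the dataloader is iterable (as in the loops below) and returns a dictionary-like item with an 'Image' entry. The exact set of keys may differ depending on your py-feat version.

# Sanity check: inspect the first item returned by the DISFA Plus dataloader
sample = next(iter(disfa_file_data))
print(type(sample), list(sample.keys()))  # 'Image' is used throughout this tutorial; other keys hold the annotations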

1. Test XGB model#

# Initialize the detector, using the XGB model for AU detection
detector = Detector(face_model='retinaface', emotion_model='resmasknet', landmark_model="mobilefacenet", au_model='xgb', device='cpu')
aus_classified = []  # per-image AU predictions
au_labels = []       # ground-truth AU intensities from DISFA Plus
for i, imgs in enumerate(tqdm(disfa_file_data)):
    faces = detector.detect_faces(imgs['Image'])
    landmarks = detector.detect_landmarks(imgs['Image'], detected_faces=faces)
    poses = detector.detect_facepose(imgs['Image'])  # not used for AU detection; kept for completeness
    aus = detector.detect_aus(imgs['Image'], landmarks)
    aus_classified.append(aus)
    au_labels.append(imgs['aus'])  # NOTE: assumed key for the annotated AU intensities; check your dataloader's output

with open(save_result_dir + 'xgb_au_predictions.pkl', 'wb') as fp:
    # Save predictions together with the ground-truth labels so the evaluation cell can load both
    pickle.dump((aus_classified, pd.concat(au_labels)), fp)  # assumes the labels are pandas rows

Calculate F1 score metrics#

with open(save_result_dir + 'xgb_au_predictions.pkl', 'rb') as fp:
    aus_classified, labels = pickle.load(fp)
predictions = np.squeeze(np.stack(aus_classified))

# AU columns output by the py-feat AU models, in order
predicted_aus = ["AU1", "AU2", "AU4", "AU5", "AU6", "AU7", "AU9", "AU10", "AU11", "AU12", "AU14", "AU15",
                 "AU17", "AU20", "AU23", "AU24", "AU25", "AU26", "AU28", "AU43"]
# AUs annotated in DISFA Plus that overlap with the model output
testing_aus = ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU9', 'AU12', 'AU15', 'AU17', 'AU20', 'AU25', 'AU26']

for auname in testing_aus:
    print('======')
    index_arr = predicted_aus.index(auname)
    # An AU counts as present when its annotated intensity is >= 2; predicted probabilities are thresholded at 0.5
    print(auname, ', f1 score: ', f1_score((labels[auname] >= 2).astype(int),
                                           (predictions[:, index_arr] >= 0.5).astype(int)))
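
The printed scores can also be collected into a single pandas table, which makes it easier to scan across AUs. This is an optional sketch that simply reuses the predictions, labels, predicted_aus, and testing_aus variables defined above.

# Optional: collect the per-AU F1 scores into one summary table
f1_rows = []
for auname in testing_aus:
    idx = predicted_aus.index(auname)
    score = f1_score((labels[auname] >= 2).astype(int), (predictions[:, idx] >= 0.5).astype(int))
    f1_rows.append({'AU': auname, 'f1': score})
f1_summary = pd.DataFrame(f1_rows).set_index('AU')
print(f1_summary)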

2. Test SVM model#

# Re-initialize the detector, this time with the SVM AU model
detector = Detector(face_model='retinaface', emotion_model='resmasknet', landmark_model="mobilefacenet", au_model='svm', device='cpu')

aus_classified = []  # per-image AU predictions
au_labels = []       # ground-truth AU intensities from DISFA Plus
for i, imgs in enumerate(tqdm(disfa_file_data)):
    faces = detector.detect_faces(imgs['Image'])
    landmarks = detector.detect_landmarks(imgs['Image'], detected_faces=faces)
    poses = detector.detect_facepose(imgs['Image'])  # not used for AU detection; kept for completeness
    aus = detector.detect_aus(imgs['Image'], landmarks)
    aus_classified.append(aus)
    au_labels.append(imgs['aus'])  # NOTE: assumed key for the annotated AU intensities; check your dataloader's output

with open(save_result_dir + 'svm_au_predictions.pkl', 'wb') as fp:
    # Save predictions together with the ground-truth labels so the evaluation cell can load both
    pickle.dump((aus_classified, pd.concat(au_labels)), fp)  # assumes the labels are pandas rows

Calculate F1 score metrics#

with open(save_result_dir + 'svm_au_predictions.pkl', 'rb') as fp:
    aus_classified, labels = pickle.load(fp)
predictions = np.squeeze(np.stack(aus_classified))

predicted_aus = ["AU1", "AU2", "AU4", "AU5", "AU6", "AU7", "AU9", "AU10", "AU11", "AU12", "AU14", "AU15",
                 "AU17", "AU20", "AU23", "AU24", "AU25", "AU26", "AU28", "AU43"]
testing_aus = ['AU1', 'AU2', 'AU4', 'AU5', 'AU6', 'AU9', 'AU12', 'AU15', 'AU17', 'AU20', 'AU25', 'AU26']

for auname in testing_aus:
    print('======')
    index_arr = predicted_aus.index(auname)
    # Same evaluation criterion as for the XGB model: intensity >= 2 counts as present, predictions thresholded at 0.5
    print(auname, ', f1 score: ', f1_score((labels[auname] >= 2).astype(int),
                                           (predictions[:, index_arr] >= 0.5).astype(int)))
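
Once both sets of predictions have been saved, the two AU models can be compared side by side. The sketch below is optional; it assumes both pickle files were written with the (predictions, labels) layout used in this tutorial and reuses predicted_aus, testing_aus, and save_result_dir from above.

# Compare per-AU F1 scores for the XGB and SVM AU models
results = {}
for model_name in ['xgb', 'svm']:
    with open(save_result_dir + f'{model_name}_au_predictions.pkl', 'rb') as fp:
        preds_list, gt = pickle.load(fp)
    preds = np.squeeze(np.stack(preds_list))
    results[model_name] = [
        f1_score((gt[au] >= 2).astype(int), (preds[:, predicted_aus.index(au)] >= 0.5).astype(int))
        for au in testing_aus
    ]
comparison = pd.DataFrame(results, index=testing_aus)
print(comparison)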