8. Training HOG-based AU detectors

written by Tiankang Xie

In this tutorial we demonstrate how to train the HOG-based AU models described in our paper. The tutorial is split into 4 parts: Part 1 extracts HOG and landmark features from the dataset, Part 2 reduces the dimensionality of the extracted HOGs, Part 3 uses the reduced features to train and cross-validate a classifier, and Part 4 tests the trained model on additional benchmark data.

Part 1: Extracting HOGs and Landmarks

### To speed up training the HOGs, we will first try to extract the HOG and landmark features from image paths using py-feat
import sys
import torch
import torch.nn as nn
import math
from feat.utils import set_torch_device
import torch.nn.functional as F
from copy import deepcopy
import numpy as np
from skimage import draw
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from feat.utils.image_operations import extract_face_from_landmarks

from PIL import Image
from itertools import product
import os 
from torchvision.transforms import Compose, Normalize, Grayscale

import pandas as pd
from tqdm import tqdm
from feat import Detector

from joblib import delayed, Parallel
from torchvision.utils import save_image
from torchvision.io import read_image, read_video
from torch.utils.data import Dataset
from feat.transforms import Rescale
import glob
from skimage.feature import hog
import pickle
from torch.utils.data import DataLoader
from feat.data import (
    Fex,
    ImageDataset,
    VideoDataset,
    _inverse_face_transform,
    _inverse_landmark_transform,
)
import glob
from feat.utils.image_operations import (
    extract_face_from_landmarks,
    extract_face_from_bbox,
    convert_image_to_tensor,
    BBox,
)
# Load the AU annotation table; the first CSV column is used as the row index.
au_df = pd.read_csv('/home/tiankang/AU_Dataset/EmotioNet/EmotioNet_master.csv', index_col=0)
# This is the file of the AU annotations.
# Expected layout: |filepath|AU1|AU2|AU4|..., where the `filepath` column holds the
# path of the input image and every AU column holds that image's label for the AU.
au_df.head() 
filepath AU1 AU2 AU4 AU5 AU6 AU9 AU12 AU17 AU20 AU25 AU26 AU43
0 /Storage/Data/EmotioNet/imgs/N_0000000001_0000... 0.0 999.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
1 /Storage/Data/EmotioNet/imgs/N_0000000001_0000... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
2 /Storage/Data/EmotioNet/imgs/N_0000000001_0000... 1.0 0.0 1.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
3 /Storage/Data/EmotioNet/imgs/N_0000000001_0000... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
4 /Storage/Data/EmotioNet/imgs/N_0000000001_0000... 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
# Py-Feat detector; the loop below only calls its face and landmark detection.
detector = Detector(face_model='retinaface',emotion_model='resmasknet', landmark_model="mobilefacenet", au_model='svm')

# Directory that receives one pickle of (HOG, landmarks) per input image.
SAVE_HOG_DIR = '/Storage/Projects/pyfeat_testing/HOGFeatures/MyHOGFeatures/'
# a list of all image paths named in the annotation file
input_file_list = au_df['filepath'].to_list()

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(SAVE_HOG_DIR, exist_ok=True)

# One image per batch, in annotation order (no shuffling).
data_loader = DataLoader(
    ImageDataset(
        input_file_list,
        output_size=256,
        preserve_aspect_ratio=True,
        padding=True,
    ),
    num_workers=0,
    batch_size=1,
    pin_memory=False,
    shuffle=False,
)

def _batch_hog(frames, landmarks):
    """
    Helper function used in batch processing hog features

    Args:
        frames: a batch of frames
        landmarks: a list of list of detected landmarks; the outer list is
            indexed by frame, the inner one by detected face

    Returns:
        hog_features: a flat list with one HOG descriptor per detected face.
            A frame without landmarks contributes a single
            np.zeros((1, 5408)) placeholder instead.
        hog_images: a flat list of HOG visualisations, aligned entry-for-entry
            with hog_features (an all-zero image for a frame without landmarks)
        new_landmark_frames: one list per frame holding the landmarks returned
            by extract_face_from_landmarks for each face; [np.zeros((68, 2))]
            for a frame without landmarks
    """
    face_size = 112
    frames = convert_image_to_tensor(frames, img_type="float32")
    # Build the tensor -> PIL converter once instead of once per face.
    to_pil = transforms.ToPILImage()

    hog_features = []
    hog_images = []
    new_landmark_frames = []
    for i, frame_landmark in enumerate(landmarks):
        if len(frame_landmark) == 0:
            # No face in this frame: emit placeholders so every output keeps
            # one slot for the frame.
            hog_features.append(
                np.zeros((1, 5408))
            )  # LC: Need to confirm this size is fixed.
            # assumes the HOG visualisation of a face crop is
            # face_size x face_size -- TODO confirm
            hog_images.append(np.zeros((face_size, face_size)))
            new_landmark_frames.append([np.zeros((68, 2))])
            continue

        new_landmarks_faces = []
        for face_landmark in frame_landmark:
            convex_hull, new_landmark = extract_face_from_landmarks(
                frame=frames[i],
                landmarks=face_landmark,
                face_size=face_size,
            )
            fd, hog_image = hog(
                to_pil(convex_hull[0] / 255.0),
                orientations=8,
                pixels_per_cell=(8, 8),
                cells_per_block=(2, 2),
                visualize=True,
                channel_axis=-1,
            )

            hog_features.append(fd)
            hog_images.append(hog_image)
            new_landmarks_faces.append(new_landmark)

        new_landmark_frames.append(new_landmarks_faces)

    return (hog_features, hog_images, new_landmark_frames)


for batch_data in tqdm(data_loader):
    # Iterate through all the images in dataloader to get the hog feature and landmark feature
    try:
        faces = detector.detect_faces(
            batch_data["Image"],
            threshold=0.5)

        landmarks = detector.detect_landmarks(
            batch_data["Image"],
            detected_faces=faces)

        hog_features, hog_images, new_landmark_frames = _batch_hog(batch_data["Image"], landmarks)

        # hog_features is flat (one entry per face, or one placeholder for a
        # frame without a face) while FileNames / new_landmark_frames have one
        # entry per frame. Walk the frames and keep an offset into
        # hog_features so each image is saved with the HOG and landmarks of
        # its first face; indexing FileNames by the position in hog_features
        # runs past its end as soon as a frame has more than one face.
        feature_idx = 0
        for file_name, frame_landmarks in zip(batch_data['FileNames'], new_landmark_frames):
            stem = os.path.basename(file_name).split('.')[0]
            with open(SAVE_HOG_DIR + stem + '.pkl', 'wb') as fp:
                pickle.dump((hog_features[feature_idx], frame_landmarks[0]), fp)
            feature_idx += len(frame_landmarks)

    except Exception as err:
        # Best effort: report which image failed and why, then move on.
        # Catching Exception (not a bare except) keeps Ctrl-C working.
        print('something went wrong with reading the image:', batch_data['FileNames'], repr(err))
        continue
  1%|          | 218/21938 [00:12<19:32, 18.52it/s]
something went wrong with reading the image
  1%|▏         | 302/21938 [00:16<20:07, 17.92it/s]
something went wrong with reading the image
  2%|▏         | 337/21938 [00:18<19:27, 18.50it/s]WARNING:root:Warning: NO FACE is detected
  3%|▎         | 605/21938 [00:33<20:56, 16.98it/s]
something went wrong with reading the image
  4%|▍         | 987/21938 [00:53<21:58, 15.89it/s]
something went wrong with reading the image
  6%|▌         | 1208/21938 [01:04<19:44, 17.50it/s]
something went wrong with reading the image
  6%|▌         | 1258/21938 [01:07<18:34, 18.55it/s]
something went wrong with reading the image
  6%|▋         | 1372/21938 [01:13<21:06, 16.24it/s]
something went wrong with reading the image
  7%|▋         | 1585/21938 [01:24<17:33, 19.32it/s]
something went wrong with reading the image
  9%|▉         | 2066/21938 [01:49<17:49, 18.59it/s]
something went wrong with reading the image
 11%|█         | 2318/21938 [02:02<17:28, 18.71it/s]
something went wrong with reading the image
 11%|█         | 2436/21938 [02:09<17:02, 19.07it/s]
something went wrong with reading the image
 12%|█▏        | 2578/21938 [02:16<16:13, 19.88it/s]
something went wrong with reading the image
 12%|█▏        | 2704/21938 [02:23<19:09, 16.73it/s]
something went wrong with reading the image
 14%|█▍        | 3161/21938 [02:46<16:18, 19.19it/s]
something went wrong with reading the image
 21%|██        | 4534/21938 [03:57<15:30, 18.70it/s]
something went wrong with reading the image
 21%|██▏       | 4689/21938 [04:05<14:58, 19.20it/s]WARNING:root:Warning: NO FACE is detected
 22%|██▏       | 4865/21938 [04:14<15:39, 18.17it/s]
something went wrong with reading the image
 23%|██▎       | 5013/21938 [04:22<16:13, 17.38it/s]
something went wrong with reading the image
 25%|██▍       | 5401/21938 [04:42<15:46, 17.47it/s]WARNING:root:Warning: NO FACE is detected
 25%|██▍       | 5483/21938 [04:47<14:43, 18.63it/s]
something went wrong with reading the image
 25%|██▌       | 5573/21938 [04:51<15:09, 17.99it/s]
something went wrong with reading the image
 26%|██▌       | 5757/21938 [05:01<14:06, 19.12it/s]
something went wrong with reading the image
 28%|██▊       | 6191/21938 [05:23<14:05, 18.62it/s]
something went wrong with reading the image
 29%|██▊       | 6292/21938 [05:28<14:03, 18.55it/s]
something went wrong with reading the image
 32%|███▏      | 6940/21938 [06:02<12:33, 19.90it/s]WARNING:root:Warning: NO FACE is detected
 32%|███▏      | 6984/21938 [06:04<13:07, 18.98it/s]
something went wrong with reading the image
 32%|███▏      | 7106/21938 [06:11<14:20, 17.24it/s]
something went wrong with reading the image
 34%|███▍      | 7463/21938 [06:29<12:13, 19.74it/s]
something went wrong with reading the image
 35%|███▍      | 7577/21938 [06:35<12:26, 19.23it/s]
something went wrong with reading the image
 37%|███▋      | 8151/21938 [07:04<12:09, 18.90it/s]
something went wrong with reading the image
 40%|███▉      | 8678/21938 [07:32<13:01, 16.96it/s]
something went wrong with reading the image
 40%|████      | 8796/21938 [07:38<11:36, 18.87it/s]
something went wrong with reading the image
 42%|████▏     | 9148/21938 [07:56<12:48, 16.65it/s]
something went wrong with reading the image
 44%|████▎     | 9569/21938 [08:17<11:14, 18.33it/s]
something went wrong with reading the image
 45%|████▍     | 9841/21938 [08:31<10:32, 19.12it/s]
something went wrong with reading the image
 45%|████▌     | 9926/21938 [08:35<10:13, 19.59it/s]
something went wrong with reading the image
 46%|████▌     | 9984/21938 [08:39<10:56, 18.21it/s]
something went wrong with reading the image
 50%|█████     | 10987/21938 [09:30<10:08, 17.99it/s]
something went wrong with reading the image
 52%|█████▏    | 11455/21938 [09:54<09:31, 18.34it/s]
something went wrong with reading the image
 54%|█████▎    | 11777/21938 [10:10<08:56, 18.94it/s]
something went wrong with reading the image
 54%|█████▍    | 11811/21938 [10:12<08:59, 18.77it/s]
something went wrong with reading the image
 55%|█████▍    | 12008/21938 [10:22<08:22, 19.76it/s]
something went wrong with reading the image
 59%|█████▉    | 12969/21938 [11:11<08:02, 18.60it/s]
something went wrong with reading the image
 59%|█████▉    | 13008/21938 [11:13<08:00, 18.57it/s]
something went wrong with reading the image
 60%|█████▉    | 13136/21938 [11:19<07:25, 19.78it/s]WARNING:root:Warning: NO FACE is detected
 63%|██████▎   | 13714/21938 [11:49<07:22, 18.58it/s]
something went wrong with reading the image
 63%|██████▎   | 13855/21938 [11:56<06:54, 19.52it/s]
something went wrong with reading the image
 63%|██████▎   | 13862/21938 [11:56<07:40, 17.54it/s]
something went wrong with reading the image
 64%|██████▎   | 13959/21938 [12:01<06:40, 19.94it/s]
something went wrong with reading the image
 65%|██████▍   | 14161/21938 [12:11<07:00, 18.49it/s]
something went wrong with reading the image
 65%|██████▌   | 14359/21938 [12:22<06:51, 18.40it/s]
something went wrong with reading the image
 66%|██████▌   | 14458/21938 [12:27<06:47, 18.35it/s]
something went wrong with reading the image
 67%|██████▋   | 14776/21938 [12:43<06:12, 19.22it/s]
something went wrong with reading the image
 68%|██████▊   | 14821/21938 [12:46<06:07, 19.38it/s]
something went wrong with reading the image
 68%|██████▊   | 14848/21938 [12:47<06:08, 19.22it/s]
something went wrong with reading the image
 68%|██████▊   | 14871/21938 [12:48<06:10, 19.08it/s]
something went wrong with reading the image
 69%|██████▉   | 15191/21938 [13:05<06:30, 17.26it/s]
something went wrong with reading the image
 71%|███████   | 15623/21938 [13:27<05:38, 18.63it/s]
something went wrong with reading the image
 72%|███████▏  | 15834/21938 [13:38<05:25, 18.74it/s]
something went wrong with reading the image
 75%|███████▌  | 16542/21938 [14:15<04:42, 19.11it/s]WARNING:root:Warning: NO FACE is detected
 77%|███████▋  | 16801/21938 [14:29<04:26, 19.28it/s]WARNING:root:Warning: NO FACE is detected
 77%|███████▋  | 16834/21938 [14:30<04:14, 20.03it/s]
something went wrong with reading the image
 77%|███████▋  | 16858/21938 [14:31<04:18, 19.66it/s]
something went wrong with reading the image
 79%|███████▊  | 17269/21938 [14:53<04:12, 18.48it/s]
something went wrong with reading the image
 80%|███████▉  | 17451/21938 [15:03<04:16, 17.51it/s]
something went wrong with reading the image
 85%|████████▌ | 18663/21938 [16:06<02:40, 20.40it/s]WARNING:root:Warning: NO FACE is detected
 85%|████████▌ | 18739/21938 [16:10<02:54, 18.37it/s]
something went wrong with reading the image
 86%|████████▋ | 18940/21938 [16:20<02:41, 18.54it/s]
something went wrong with reading the image
 87%|████████▋ | 19045/21938 [16:26<02:41, 17.86it/s]
something went wrong with reading the image
 87%|████████▋ | 19089/21938 [16:28<02:32, 18.62it/s]
something went wrong with reading the image
 90%|█████████ | 19816/21938 [17:06<02:01, 17.48it/s]
something went wrong with reading the image
 91%|█████████ | 19957/21938 [17:13<01:42, 19.40it/s]WARNING:root:Warning: NO FACE is detected
 95%|█████████▍| 20830/21938 [17:59<00:57, 19.12it/s]
something went wrong with reading the image
 95%|█████████▌| 20904/21938 [18:02<00:54, 19.10it/s]
something went wrong with reading the image
 97%|█████████▋| 21318/21938 [18:24<00:31, 19.44it/s]
something went wrong with reading the image
100%|██████████| 21938/21938 [18:56<00:00, 19.31it/s]

Part 2: Conduct Dimension Reduction on HOG

import pandas as pd
import numpy as np
from sklearn.metrics import classification_report
from sklearn.svm import LinearSVC, SVC
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from tqdm import tqdm 
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
import os
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.multioutput import MultiOutputClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
def compile_dataset(saved_hog_path, au_df, n_hog_features=5408, landmark_shape=(68, 2)):
    """compile the saved hog and landmark features

    For every row of ``au_df`` the pickle named after the image
    (``saved_hog_path + <file stem> + '.pkl'``) is loaded. Rows whose pickle is
    missing, or whose features do not have the expected shapes, are dropped.

    Args:
        saved_hog_path: where you saved the HOGs in the last section; it is
            used as a plain string prefix, so it must end with a path separator
        au_df: a pandas dataframe that contains filepaths and AU annotations
        n_hog_features: length a HOG descriptor must have to count as valid
        landmark_shape: shape a landmark array must have to count as valid

    Returns:
        a numpy array of hog features, shape (n_valid, n_hog_features)
        a numpy array of landmarks, shape (n_valid, *landmark_shape)
        a pandas df holding the rows of au_df that belong to the valid images
    """
    landmark_shape = tuple(landmark_shape)
    is_valid = []  # Which images are valid images detectable by Py-Feat?
    hog_feats = []  # Aggregated HOG Features
    land_feats = []  # Aggregated Landmark Features
    n_missing = 0

    for original_path in au_df['filepath'].to_list():  # Filenames in the annotation file
        pkl_path = saved_hog_path + os.path.basename(original_path).split('.')[0] + '.pkl'
        try:
            # NOTE: pickle can execute arbitrary code while loading; only point
            # this at files produced by the extraction step.
            with open(pkl_path, 'rb') as fp:
                hog_feat, new_lands = pickle.load(fp)
        except FileNotFoundError:
            # No pickle was written for this image: treat it as invalid
            # instead of aborting the whole compilation.
            n_missing += 1
            is_valid.append(False)
            continue

        # Restrict to valid HOGs; placeholders written for images without a
        # detected face have a different shape and are filtered out here.
        if np.shape(hog_feat) == (n_hog_features,) and np.shape(new_lands) == landmark_shape:
            is_valid.append(True)
            hog_feats.append(hog_feat)
            land_feats.append(new_lands)
        else:
            is_valid.append(False)

    if n_missing:
        print(f'{n_missing} feature file(s) not found under {saved_hog_path}; skipping them')

    if hog_feats:
        hog_array, land_array = np.stack(hog_feats), np.stack(land_feats)
    else:
        # np.stack refuses an empty list; return correctly shaped empty arrays.
        hog_array = np.empty((0, n_hog_features))
        land_array = np.empty((0,) + landmark_shape)

    return hog_array, land_array, au_df.iloc[np.asarray(is_valid, dtype=bool), :]
# Gather the saved HOG / landmark features for every image listed in au_df,
# together with the annotation rows of the images that yielded valid features.
trained_hogs, trained_land, labels_df = compile_dataset(saved_hog_path='/Storage/Projects/pyfeat_testing/HOGFeatures/MyHOGFeatures/',
                                                        au_df=au_df)
# Flatten each sample's landmark array into a single row.
trained_land = trained_land.reshape(trained_hogs.shape[0], -1)
# Sanity check: all three must have the same number of rows.
print(trained_hogs.shape)
print(trained_land.shape)
print(labels_df.shape)
(21929, 5408)
(21929, 136)
(21929, 13)
# Note that it is possible to use only upper face features / only lower face features / full face features to predict AUs
# For demonstration purposes we will only be using full face features in the later section.
scaler_upper, scaler_lower, scaler_full = StandardScaler(), StandardScaler(), StandardScaler()

pca_full = PCA(n_components=0.95)
pca_upper, pca_lower = PCA(n_components=0.98), PCA(n_components=0.98)


def _standardize_and_project(hogs, scaler, pca, zeroed_columns=None):
    """Fit `scaler` then `pca` on the HOGs and return the projected features.

    When `zeroed_columns` is given, those columns are set to 0 on a copy of
    `hogs` first, so the caller's array is left untouched. The intermediate
    arrays are local and are released when the function returns.
    """
    if zeroed_columns is not None:
        hogs = hogs.copy()
        hogs[:, zeroed_columns] = 0
    return pca.fit_transform(scaler.fit_transform(hogs))


# Restrict to upper feature
hog_data_upper_transformed = _standardize_and_project(
    trained_hogs, scaler_upper, pca_upper, zeroed_columns=slice(2414, None)
)

# Restrict to lower feature
hog_data_lower_transformed = _standardize_and_project(
    trained_hogs, scaler_lower, pca_lower, zeroed_columns=slice(0, 2221)
)

# Full face: every HOG column is kept
hog_data_full_transformed = _standardize_and_project(trained_hogs, scaler_full, pca_full)

Part 3: Prepare training data & label, and conduct Machine Learning

import pandas as pd
import numpy as np
from sklearn.metrics import classification_report
from sklearn.svm import LinearSVC, SVC
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from tqdm import tqdm 
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
import os
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.multioutput import MultiOutputClassifier
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score

from imblearn.over_sampling import RandomOverSampler, SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline
from imblearn import FunctionSampler
# In this tutorial we will be using full face features as an example. Feel free to use upper / lower
# Each row is the PCA-reduced full-face HOG followed by the flattened landmarks.
x_features = np.concatenate([hog_data_full_transformed, trained_land], 1)
# In this tutorial we will be using only AU1
y_features = labels_df['AU1'].to_numpy()
def func_samp(x, y):
    """Identity sampler: hand the features and labels back untouched."""
    features, labels = x, y
    return features, labels

def _run_and_testData(clf, x_features, y_features, sampling_method='under', cv_n_splits=5):
    """
    Run stratified k-fold cross-validation of ``clf`` and print the scores.

    Only samples labelled 0 or 1 take part. In every fold the training split
    is resampled, ``clf`` is refit on it, and two lines are printed:
    precision / recall / F1 on the resampled training data, then
    precision / recall / F1 / accuracy on the untouched validation split.
    Nothing is returned; ``clf`` is left fitted on the last fold.

    Args:
        clf: classifier exposing ``fit`` and ``predict``
        x_features: 2-D feature array
        y_features: label array
        sampling_method: 'under' for random undersampling, 'smote' for SMOTE
            followed by random undersampling, anything else for no resampling
        cv_n_splits: number of cross-validation folds
    """
    # Pick the resampler that is applied to each training split.
    if sampling_method == 'under':
        resampler = RandomUnderSampler(random_state=0)
    elif sampling_method == 'smote':
        resampler = Pipeline(steps=[
            ('o', SMOTE(random_state=0, sampling_strategy=0.60)),
            ('u', RandomUnderSampler(random_state=0, sampling_strategy=0.5)),
        ])
    else:
        resampler = FunctionSampler(func=func_samp)

    # Drop every sample whose label is neither 0 nor 1.
    has_binary_label = np.logical_or(y_features == 0, y_features == 1)
    keep_idx = np.nonzero(has_binary_label)[0]
    x_kept = x_features[keep_idx, :]
    y_kept = y_features[keep_idx]

    splitter = StratifiedKFold(n_splits=cv_n_splits, random_state=1, shuffle=True)

    for fit_idx, held_out_idx in splitter.split(X=x_kept, y=y_kept):
        # Resample the training split only; the validation split stays as is.
        x_bal, y_bal = resampler.fit_resample(x_kept[fit_idx], y_kept[fit_idx])
        clf.fit(x_bal, y_bal)

        train_prec, train_rec, train_f1, _ = precision_recall_fscore_support(
            y_true=y_bal, y_pred=clf.predict(x_bal), average='binary'
        )
        print('training score:', train_prec, train_rec, train_f1)

        y_held_out = y_kept[held_out_idx]
        held_out_pred = clf.predict(x_kept[held_out_idx])
        val_prec, val_rec, val_f1, _ = precision_recall_fscore_support(
            y_true=y_held_out, y_pred=held_out_pred, average='binary'
        )
        val_acc = accuracy_score(y_true=y_held_out, y_pred=held_out_pred)
        print('validation score:', val_prec, val_rec, val_f1, val_acc)
# Linear SVM used as the AU1 detector.
model_AU1 = LinearSVC(penalty='l2', C=5e-5, loss='squared_hinge', tol=2e-4, max_iter=2000)
# 5-fold cross-validation with random undersampling of each training split.
# NOTE: the function refits model_AU1 in every fold, so afterwards the model
# holds the fit from the last fold.
_run_and_testData(clf=model_AU1,
                  x_features=x_features, y_features=y_features, sampling_method='under', cv_n_splits=5)
training score: 0.8882575757575758 0.9036608863198459 0.8958930276981855
validation score: 0.19434306569343066 0.8223938223938224 0.31439114391143913 0.7734146341463415
training score: 0.878645343367827 0.9006750241080038 0.8895238095238095
validation score: 0.180073126142596 0.7576923076923077 0.29098966026587886 0.7658536585365854
training score: 0.8709073900841908 0.8977820636451301 0.8841405508072175
validation score: 0.19941634241245138 0.7884615384615384 0.3183229813664596 0.7858536585365854
training score: 0.8782771535580525 0.9036608863198459 0.8907882241215574
validation score: 0.1894150417827298 0.7876447876447876 0.3053892215568862 0.7736033178824103
training score: 0.8711484593837535 0.8988439306358381 0.8847795163584637
validation score: 0.18973418881759854 0.7992277992277992 0.3066666666666667 0.7716516223469139

Part 4: Validate the results on benchmark data

## We will only be using a small subset of DisfaPlus for testing, as an example
# The first CSV column is used as the row index.
disfaP_toy_df = pd.read_csv(
    '/home/tiankang/src/feat/dev/disfaP_toy.csv', index_col=0
)
# This is the csv file that contains filepath and AU labels
disfaP_toy_df.head()
Unnamed: 0 aligned_landmark subject task frame AU1 AU2 AU4 AU5 AU6 AU9 AU12 AU15 AU17 AU20 AU25 AU26 filepath
33543 5885 [ 42.96360929 43.31080323 53.19024402 38.23... SN010 Y_SadDescribed_TrailNo_1 34 1 1 2 1 0 0 0 0 0 0 0 0 /Storage/Data/DISFAPlusDataset/Images/SN010/Y_...
53616 2508 [ 43.97696623 51.17963593 53.27121144 45.34... SN025 Y_FearDescribed_TrailNo_1 74 0 0 0 0 0 0 0 0 0 0 0 0 /Storage/Data/DISFAPlusDataset/Images/SN025/Y_...
46436 1824 [ 47.81582882 53.83295079 57.0445239 46.36... SN001 C1_AU26_TrailNo_2 94 0 0 0 0 0 0 0 0 0 0 0 0 /Storage/Data/DISFAPlusDataset/Images/SN001/C1...
26573 7612 [ 47.98804535 48.04893453 57.13419986 42.65... SN027 A3_AU1_2_TrailNo_1 99 0 0 0 0 0 0 0 0 0 0 0 0 /Storage/Data/DISFAPlusDataset/Images/SN027/A3...
37144 3215 [ 45.78620683 48.71779324 55.53607682 42.76... SN007 A7_AU5z_TrailNo_2 32 0 0 0 0 0 0 0 0 0 0 0 0 /Storage/Data/DISFAPlusDataset/Images/SN007/A7...
## Again use the pyfeat modules to get the HOG and Landmark features
detector = Detector(face_model='retinaface',emotion_model='resmasknet', landmark_model="mobilefacenet", au_model='svm')

# Directory that receives one pickle of (HOG, landmarks) per test image.
SAVE_HOG_DIR = '/Storage/Projects/pyfeat_testing/HOGFeatures/MyHOGTestFeatures/'
# a list of all image paths named in the test annotation file
input_file_list = disfaP_toy_df['filepath'].to_list()

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(SAVE_HOG_DIR, exist_ok=True)

# One image per batch, in annotation order (no shuffling).
data_loader = DataLoader(
    ImageDataset(
        input_file_list,
        output_size=256,
        preserve_aspect_ratio=True,
        padding=True,
    ),
    num_workers=0,
    batch_size=1,
    pin_memory=False,
    shuffle=False,
)

def _batch_hog(frames, landmarks):
    """
    Helper function used in batch processing hog features

    Args:
        frames: a batch of frames
        landmarks: a list of list of detected landmarks; the outer list is
            indexed by frame, the inner one by detected face

    Returns:
        hog_features: a flat list with one HOG descriptor per detected face.
            A frame without landmarks contributes a single
            np.zeros((1, 5408)) placeholder instead.
        hog_images: a flat list of HOG visualisations, aligned entry-for-entry
            with hog_features (an all-zero image for a frame without landmarks)
        new_landmark_frames: one list per frame holding the landmarks returned
            by extract_face_from_landmarks for each face; [np.zeros((68, 2))]
            for a frame without landmarks
    """
    face_size = 112
    frames = convert_image_to_tensor(frames, img_type="float32")
    # Build the tensor -> PIL converter once instead of once per face.
    to_pil = transforms.ToPILImage()

    hog_features = []
    hog_images = []
    new_landmark_frames = []
    for i, frame_landmark in enumerate(landmarks):
        if len(frame_landmark) == 0:
            # No face in this frame: emit placeholders so every output keeps
            # one slot for the frame.
            hog_features.append(
                np.zeros((1, 5408))
            )  # LC: Need to confirm this size is fixed.
            # assumes the HOG visualisation of a face crop is
            # face_size x face_size -- TODO confirm
            hog_images.append(np.zeros((face_size, face_size)))
            new_landmark_frames.append([np.zeros((68, 2))])
            continue

        new_landmarks_faces = []
        for face_landmark in frame_landmark:
            convex_hull, new_landmark = extract_face_from_landmarks(
                frame=frames[i],
                landmarks=face_landmark,
                face_size=face_size,
            )
            fd, hog_image = hog(
                to_pil(convex_hull[0] / 255.0),
                orientations=8,
                pixels_per_cell=(8, 8),
                cells_per_block=(2, 2),
                visualize=True,
                channel_axis=-1,
            )

            hog_features.append(fd)
            hog_images.append(hog_image)
            new_landmarks_faces.append(new_landmark)

        new_landmark_frames.append(new_landmarks_faces)

    return (hog_features, hog_images, new_landmark_frames)


for batch_data in tqdm(data_loader):
    # Iterate through all the images in dataloader to get the hog feature and landmark feature
    try:
        faces = detector.detect_faces(
            batch_data["Image"],
            threshold=0.5)

        landmarks = detector.detect_landmarks(
            batch_data["Image"],
            detected_faces=faces)

        hog_features, hog_images, new_landmark_frames = _batch_hog(batch_data["Image"], landmarks)

        # hog_features is flat (one entry per face, or one placeholder for a
        # frame without a face) while FileNames / new_landmark_frames have one
        # entry per frame. Walk the frames and keep an offset into
        # hog_features so each image is saved with the HOG and landmarks of
        # its first face; indexing FileNames by the position in hog_features
        # runs past its end as soon as a frame has more than one face.
        feature_idx = 0
        for file_name, frame_landmarks in zip(batch_data['FileNames'], new_landmark_frames):
            stem = os.path.basename(file_name).split('.')[0]
            with open(SAVE_HOG_DIR + stem + '.pkl', 'wb') as fp:
                pickle.dump((hog_features[feature_idx], frame_landmarks[0]), fp)
            feature_idx += len(frame_landmarks)

    except Exception as err:
        # Best effort: report which image failed and why, then move on.
        # Catching Exception (not a bare except) keeps Ctrl-C working.
        print('something went wrong with reading the image:', batch_data['FileNames'], repr(err))
        continue
100%|██████████| 100/100 [00:06<00:00, 15.38it/s]
# Obtain the HOG & Land Features for test dataset
test_hogs, test_land, test_labels_df = compile_dataset(saved_hog_path='/Storage/Projects/pyfeat_testing/HOGFeatures/MyHOGTestFeatures/',
                                                        au_df=disfaP_toy_df)
# Flatten each sample's landmark array into a single row.
test_land = test_land.reshape(test_hogs.shape[0], -1)
# Apply dimension reduction on HOG
# Reuse the scaler and PCA fitted on the training HOGs (transform only, no refit).
test_data_full_std = scaler_full.transform(test_hogs)
test_data_full_transformed = pca_full.transform(test_data_full_std)
# Same feature layout as in training: reduced HOG followed by flattened landmarks.
test_feature = np.concatenate([test_data_full_transformed, test_land], 1)
# Predict and calculate score
AU1_predicted = model_AU1.predict(test_feature)
AU1_labels = np.where(test_labels_df['AU1'] > 0, 1, 0) # Note that DISFAP uses labels that range from 0 to 5. 
# We binarize this label to 0 and 1.

print("tested F1 score: ", f1_score(y_true=AU1_labels, y_pred=AU1_predicted))
tested F1 score:  0.41558441558441556