Files
2026-01-09 10:28:44 +11:00

101 lines
3.7 KiB
Python

# Copyright 2022 David Scripka. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Imports
import re
from tqdm import tqdm
import numpy as np
from typing import List
# Define metric utility functions specific to the wakeword detection use-case
def get_false_positives(scores: List, threshold: float, grouping_window: int = 50):
    """
    Counts the number of false-positives based on a list of scores and a specified threshold.

    A frame is positive when its score is greater than or equal to `threshold`.
    For every below-threshold -> above-threshold transition, the first positive
    frame is kept and, if the frame right after it is also positive, the next
    `grouping_window` frames are cleared so that the whole burst is counted once.

    Note: a run of positive frames that starts at the very first frame has no
    preceding below-threshold frame, so it is not grouped and each of its frames
    is counted individually.

    Args:
        scores (List): A one-dimensional list of predicted scores, between 0 and 1
        threshold (float): The threshold to use to determine false-positive predictions
        grouping_window (int): The size (in number of frames) for grouping scores above
                               the threshold into a single false positive for counting

    Returns:
        int: The number of false positive predictions in the list of scores
    """
    # The comparison allocates a fresh boolean array, so the caller's data is never mutated
    bin_pred = np.asarray(scores) >= threshold
    n_frames = len(bin_pred)

    # Index of the first positive frame of every 0 -> 1 transition
    onsets = np.flatnonzero(~bin_pred[:-1] & bin_pred[1:]) + 1

    for onset in onsets:
        start = onset + 1
        # The bounds check covers a transition that ends exactly on the last frame.
        # The current (possibly already cleared) state is inspected on purpose, so
        # windows opened by earlier transitions are taken into account.
        if start < n_frames and bin_pred[start]:
            # Slice assignment is clamped to the array length automatically
            bin_pred[start:start + grouping_window] = False

    return int(np.count_nonzero(bin_pred))
def generate_roc_curve_fprs(
    scores: list,
    n_points: int = 25,
    time_per_prediction: float = .08,
    **kwargs
):
    """
    Computes the number of false positives per hour over a sweep of score thresholds.

    Every score is assumed to come from data where nothing should be detected, so
    any prediction counted by `get_false_positives` is treated as a false positive.
    The thresholds are `n_points` evenly spaced values from 0.01 to 0.99 (inclusive).

    Args:
        scores (List): A list of predicted scores, between 0 and 1
        n_points (int): The number of thresholds at which to evaluate the false positive rate
        time_per_prediction (float): The time (in seconds) that each prediction represents
        kwargs (dict): Any other keyword arguments to forward to the `get_false_positives` function

    Returns:
        list: The false positives per hour, one entry per threshold in ascending threshold order
    """
    # Duration covered by all predictions, expressed in hours
    hours_covered = time_per_prediction*len(scores)/3600

    # Count false positives at each threshold and normalise by the covered duration
    cutoffs = np.linspace(0.01, 0.99, num=n_points)
    return [
        get_false_positives(scores, threshold=cutoff, **kwargs)/hours_covered
        for cutoff in tqdm(cutoffs)
    ]
def generate_roc_curve_tprs(
    scores: list,
    n_points: int = 25
):
    """
    Generates the true positive rate (true accept rate) for the given predictions
    over a range of score thresholds. Assumes that all predictions are supposed to be equal to 1.

    The thresholds are `n_points` evenly spaced values from 0.01 to 0.99 (inclusive).
    A prediction counts as a true positive when its score is greater than or equal
    to the threshold.

    Args:
        scores (list): A list (or array) of scores for each prediction
        n_points (int): The number of thresholds at which to evaluate the true positive rate

    Returns:
        list: A list of true positive rates at different score threshold levels

    Raises:
        ZeroDivisionError: If `scores` is empty
    """
    # Convert once so that plain Python lists support the element-wise comparison
    # below (a bare `list >= float` raises TypeError)
    scores_arr = np.asarray(scores)
    n_scores = len(scores_arr)

    tprs = []
    for threshold in tqdm(np.linspace(0.01, 0.99, num=n_points)):
        tprs.append(np.count_nonzero(scores_arr >= threshold)/n_scores)

    return tprs