Custom Evaluation Metric
import turboml as tb
import pandas as pd
from turboml.common import ModelMetricAggregateFunction
import math
Model Metric Aggregation function
Metric aggregate functions are used to add and compute any custom metric over model predictions and labels.
Overview of Metric Aggregate Functions
A metric aggregate function consists of the following lifecycle methods:
- create_state(): Initializes the aggregation state.
- accumulate(state, prediction, label): Updates the state based on input values.
- retract(state, prediction, label) (optional): Reverses the effect of previously accumulated values (useful in sliding windows or similar contexts).
- merge_states(state1, state2): Merges two states (for distributed computation).
- finish(state): Computes and returns the final metric value.
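The order in which these methods are called can be sketched in plain Python. The class below is a hypothetical stand-in (a mean absolute error metric) that only mirrors the method names; it does not import turboml or subclass ModelMetricAggregateFunction, so it illustrates the lifecycle rather than the platform's actual runtime behaviour.

```python
# Hypothetical stand-in that mirrors the lifecycle method names.
# It is driven by hand here; on the platform the runtime makes these calls.
class MeanAbsoluteError:
    def create_state(self):
        # State is (sum of absolute errors, number of samples).
        return (0.0, 0)

    def accumulate(self, state, prediction, label):
        if prediction is None or label is None:
            return state
        total, count = state
        return (total + abs(prediction - label), count + 1)

    def retract(self, state, prediction, label):
        # Undo one earlier accumulate call for the same pair.
        if prediction is None or label is None:
            return state
        total, count = state
        return (total - abs(prediction - label), count - 1)

    def merge_states(self, state1, state2):
        return (state1[0] + state2[0], state1[1] + state2[1])

    def finish(self, state):
        total, count = state
        return 0.0 if count == 0 else total / count


metric = MeanAbsoluteError()

# Two partial states, as if they were computed on separate workers.
left = metric.create_state()
for prediction, label in [(0.5, 1.0), (0.25, 0.0)]:
    left = metric.accumulate(left, prediction, label)

right = metric.create_state()
right = metric.accumulate(right, 1.0, 1.0)

merged = metric.merge_states(left, right)
mae_all = metric.finish(merged)

# Retracting a sample removes its contribution, e.g. when it leaves a window.
shrunk = metric.retract(merged, 0.5, 1.0)
mae_after_retract = metric.finish(shrunk)
```

Keeping the state as a small tuple of running totals is what makes both merge_states and retract simple additions and subtractions.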
Steps to Define a Metric Aggregate Function
1. Define a Subclass
Create a subclass of ModelMetricAggregateFunction and override its methods.
2. Implement Required Methods
At a minimum, one needs to implement:
- create_state
- accumulate
- finish
- merge_states
Example: Focal Loss Metric
Here’s an example of a custom focal loss metric function.
class FocalLoss(ModelMetricAggregateFunction):
    def __init__(self):
        super().__init__()

    def create_state(self):
        """
        Initialize the aggregation state.
        Returns:
            Any: A serializable object representing the initial state of the metric.
            This can be a tuple, dictionary, or any other serializable data structure.
        Note:
            - The serialized size of the state should be less than 8MB to ensure
              compatibility with distributed systems and to avoid exceeding storage
              or transmission limits.
            - Ensure the state is lightweight and efficiently encodable for optimal
              performance.
        """
        return (0.0, 0)

    def _compute_focal_loss(self, prediction, label, gamma=2.0, alpha=0.25):
        if prediction is None or label is None:
            return None
        pt = prediction if label == 1 else 1 - prediction
        pt = max(min(pt, 1 - 1e-6), 1e-6)
        return -alpha * ((1 - pt) ** gamma) * math.log(pt)

    def accumulate(self, state, prediction, label):
        """
        Update the state with a new prediction-target pair.
        Args:
            state (Any): The current aggregation state.
            prediction (float): Predicted value.
            label (float): Ground truth.
        Returns:
            Any: The updated aggregation state, maintaining the same format and requirements as `create_state`.
        """
        loss_sum, weight_sum = state
        focal_loss = self._compute_focal_loss(prediction, label)
        if focal_loss is None:
            return state
        return loss_sum + focal_loss, weight_sum + 1

    def finish(self, state):
        """
        Compute the final metric value.
        Args:
            state (Any): Final state.
        Returns:
            float: The result.
        """
        loss_sum, weight_sum = state
        return 0 if weight_sum == 0 else loss_sum / weight_sum

    def merge_states(self, state1, state2):
        """
        Merge two states (for distributed computations).
        Args:
            state1 (Any): The first aggregation state.
            state2 (Any): The second aggregation state.
        Returns:
            Any: Merged state, maintaining the same format and requirements as `create_state`.
        """
        loss_sum1, weight_sum1 = state1
        loss_sum2, weight_sum2 = state2
        return loss_sum1 + loss_sum2, weight_sum1 + weight_sum2
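To get a feel for the values this metric produces, the per-sample formula can be evaluated on its own. The sketch below copies the formula and defaults from _compute_focal_loss into a standalone function (focal_loss is a name chosen here for illustration) and compares a confident correct prediction with a confident wrong one.

```python
import math


def focal_loss(prediction, label, gamma=2.0, alpha=0.25):
    # Same formula and defaults as _compute_focal_loss in the class above.
    pt = prediction if label == 1 else 1 - prediction
    pt = max(min(pt, 1 - 1e-6), 1e-6)  # clamp so math.log never receives 0
    return -alpha * ((1 - pt) ** gamma) * math.log(pt)


easy = focal_loss(0.9, 1)     # confident and correct: pt = 0.9
hard = focal_loss(0.1, 1)     # confident and wrong: pt = 0.1
clamped = focal_loss(1.0, 0)  # pt would be 0 without the clamp

print(f"easy={easy:.6f} hard={hard:.6f} clamped={clamped:.6f}")
```

The (1 - pt) ** gamma factor shrinks the loss of well-classified samples, so the wrong prediction contributes over a thousand times more to the running sum than the correct one, and the clamp keeps the loss finite even when the prediction is exactly 0 or 1.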
Guidelines for Implementation
- State Management:
  - Ensure the state is serializable and that its serialized size is less than 8MB.
- Edge Cases:
  - Handle cases where inputs might be None.
  - Ensure finish() handles empty states gracefully.
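One rough way to check the size guideline locally is to serialize a candidate state and measure it. The sketch below uses Python's pickle purely as a proxy; the encoding the platform actually uses is not specified here and may give different sizes, and MAX_STATE_BYTES and serialized_size are names introduced for this illustration.

```python
import pickle

# The 8MB limit, interpreted here as 8 MiB (an assumption).
MAX_STATE_BYTES = 8 * 1024 * 1024


def serialized_size(state):
    # pickle is only a stand-in for whatever encoding the platform uses.
    return len(pickle.dumps(state))


# A running sum and a count stay tiny no matter how many samples are seen.
compact_state = (0.0, 0)

# Keeping every sample in the state grows without bound.
bulky_state = [float(i) for i in range(2_000_000)]

compact_ok = serialized_size(compact_state) < MAX_STATE_BYTES
bulky_ok = serialized_size(bulky_state) < MAX_STATE_BYTES

print(f"compact state fits: {compact_ok}, bulky state fits: {bulky_ok}")
```

This is the reason the focal loss example stores only a loss sum and a count rather than the individual predictions.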
We will create one model to test the metric. Please follow the quickstart doc for details.
transactions_df = pd.read_csv("data/transactions.csv")
labels_df = pd.read_csv("data/labels.csv")
transactions_df = transactions_df.reset_index()
labels_df = labels_df.reset_index()
transactions = tb.PandasDataset(
    dataset_name="transactions_custom_metric",
    key_field="index",
    dataframe=transactions_df,
    upload=True,
)
labels = tb.PandasDataset(
    dataset_name="transaction_labels_custom_metric",
    key_field="index",
    dataframe=labels_df,
    upload=True,
)
model = tb.HoeffdingTreeClassifier(n_classes=2)
numerical_fields = [
    "transactionAmount",
    "localHour",
]
features = transactions.get_input_fields(numerical_fields=numerical_fields)
label = labels.get_label_field(label_field="is_fraud")
deployed_model_hft = model.deploy(name="demo_model_hft", input=features, labels=label)
outputs = deployed_model_hft.get_outputs()
We can now register the metric and retrieve its evaluations.
tb.register_custom_metric("FocalLoss", FocalLoss)
model_scores = deployed_model_hft.get_evaluation("FocalLoss")
model_scores[-1]
import matplotlib.pyplot as plt
plt.plot([model_score.metric for model_score in model_scores])