sega_learn.neural_networks.loss_jit

 
Modules
       
numpy

 
Classes
       
builtins.object
JITBCEWithLogitsLoss
JITCrossEntropyLoss
JITHuberLoss
JITMeanAbsoluteErrorLoss
JITMeanSquaredErrorLoss

 
class JITBCEWithLogitsLoss(builtins.object)
    Custom binary cross entropy loss with logits implementation using numba.
 
Formula: -mean(y * log(p) + (1 - y) * log(1 - p))
 
Methods:
    calculate_loss(self, logits, targets): Calculate the binary cross entropy loss.
 
  Methods defined here:
__init__(self)
Initializes the class with default values for logits and targets.
 
Attributes:
    logits (numpy.ndarray): A 2D array initialized to zeros with shape (1, 1),
                            representing the predicted values.
    targets (numpy.ndarray): A 2D array initialized to zeros with shape (1, 1),
                             representing the true target values.
calculate_loss(self, logits, targets)
Calculate the binary cross entropy loss.
 
Args:
    logits (np.ndarray): The logits (predicted values) of shape (num_samples,).
    targets (np.ndarray): The target labels of shape (num_samples,).
 
Returns:
    float: The binary cross entropy loss.

Data descriptors defined here:
__dict__
dictionary for instance variables
__weakref__
list of weak references to the object

 
class JITCrossEntropyLoss(builtins.object)
    Custom cross entropy loss implementation using numba for multi-class classification.
 
Formula: -sum(y * log(p) + (1 - y) * log(1 - p)) / m
Methods:
    calculate_loss(self, logits, targets): Calculate the cross entropy loss.
 
  Methods defined here:
__init__(self)
Initializes the instance variables for the class.
 
Args:
    logits: (np.ndarray) - A 2D array initialized to zeros with shape (1, 1),
               representing the predicted values or outputs of the model.
    targets: (np.ndarray) - A 2D array initialized to zeros with shape (1, 1),
                representing the ground truth or target values.
calculate_loss(self, logits, targets)
Calculate the cross entropy loss.
 
Args:
    logits (np.ndarray): The logits (predicted values) of shape (num_samples, num_classes).
    targets (np.ndarray): The target labels of shape (num_samples,).
 
Returns:
    float: The cross entropy loss.

Data descriptors defined here:
__dict__
dictionary for instance variables
__weakref__
list of weak references to the object

 
class JITHuberLoss(builtins.object)
    JITHuberLoss(delta=1.0)
 
Custom Huber loss implementation using numba.
 
Attributes:
    delta (float): The threshold parameter for Huber loss. Default is 1.0.
 
  Methods defined here:
__init__(self, delta=1.0)
Initializes the JITHuberLoss instance.
 
Args:
    delta (float): The threshold at which the loss function transitions
                   from quadratic to linear. Default is 1.0.
calculate_loss(self, y_pred, y_true)
Calculate the Huber loss using the stored delta.
 
Args:
    y_pred (np.ndarray): Predicted values.
    y_true (np.ndarray): True target values.
 
Returns:
    float: The calculated Huber loss.

Data descriptors defined here:
__dict__
dictionary for instance variables
__weakref__
list of weak references to the object

 
class JITMeanAbsoluteErrorLoss(builtins.object)
    Custom mean absolute error loss implementation using numba.
 
  Methods defined here:
calculate_loss(self, y_pred, y_true)
Calculate the mean absolute error loss.

Data descriptors defined here:
__dict__
dictionary for instance variables
__weakref__
list of weak references to the object

 
class JITMeanSquaredErrorLoss(builtins.object)
    Custom mean squared error loss implementation using numba.
 
  Methods defined here:
calculate_loss(self, y_pred, y_true)
Calculate the mean squared error loss.

Data descriptors defined here:
__dict__
dictionary for instance variables
__weakref__
list of weak references to the object

 
Functions
       
calculate_bce_with_logits_loss(logits, targets)
Helper function to calculate the binary cross entropy loss.
calculate_cross_entropy_loss(logits, targets)
Helper function to calculate the cross entropy loss.

 
Data
        CACHE = False