# Source code for libuplift.utils.validation

"""Utilities for input validation."""

import numpy as np

from sklearn.utils import column_or_1d, check_consistent_length

def check_trt(trt, n_trt=None, y=None):
    """Validate a treatment-assignment vector.

    Treatment values must be non-negative integers no larger than ``n_trt``,
    i.e. they must lie in ``0..n_trt``.

    Parameters
    ----------
    trt : array-like
        Treatment indicator per sample. Converted to a 1-d ndarray with
        ``sklearn.utils.column_or_1d``.
    n_trt : int or None, default=None
        Largest allowed treatment value. When ``None`` it is inferred as
        ``max(trt)``. (Presumably this counts the non-control treatments,
        with 0 as control — confirm against callers.)
    y : array-like or None, default=None
        Optional target vector. When given, it must have the same number of
        samples as ``trt``.

    Returns
    -------
    trt : ndarray
        The validated 1-d treatment array.
    n_trt : int
        ``n_trt`` as passed in, or the inferred ``max(trt)`` as a Python int.

    Raises
    ------
    ValueError
        If ``trt`` is not 1-d-convertible, does not match ``y`` in length, is
        not integer-typed, is empty, contains negative values, or contains
        values above ``n_trt``.
    """
    trt = column_or_1d(trt)
    if y is not None:
        # ``y`` was previously accepted but silently ignored; enforce the
        # sample-count agreement the imported helper exists for.
        check_consistent_length(trt, y)
    if not np.issubdtype(trt.dtype, np.integer):
        raise ValueError("Treatment values must be integers")
    if trt.size == 0:
        # Without this guard ``max`` raises numpy's opaque zero-size error.
        raise ValueError("Treatment array must not be empty")
    if (trt < 0).any():
        raise ValueError("Treatment values must be >= 0")
    max_trt = int(trt.max())
    if n_trt is None:
        n_trt = max_trt
    elif max_trt > n_trt:
        raise ValueError("Treatment values must be <= n_trt")
    return trt, n_trt