Skip to content

Commit

Permalink
Merge pull request #36 from msamsami/stats-sub-package
Browse files Browse the repository at this point in the history
patch: Add a new sub-package `wnb.stats` to keep distribution-related modules
  • Loading branch information
msamsami authored Sep 22, 2024
2 parents 2e17f79 + 5621875 commit f2e6657
Show file tree
Hide file tree
Showing 18 changed files with 415 additions and 354 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

<div align="center">

![Latest Release](https://img.shields.io/badge/release-v0.2.7-green)
![Latest Release](https://img.shields.io/badge/release-v0.3.0-green)
[![PyPI Version](https://img.shields.io/pypi/v/wnb)](https://pypi.org/project/wnb/)
![Python Versions](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)<br>
![GitHub Workflow Status (build)](https://github.com/msamsami/wnb/actions/workflows/python-publish.yml/badge.svg)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ readme = "README.md"
readme-content-type = "text/markdown"
keywords = [
"python",
"machine learning",
"bayes",
"naive bayes",
"classifier",
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ long_description = file: README.md
long_description_content_type = text/markdown
license = BSD
license_file = LICENSE
keywords = python, bayes, naive bayes, classifier, probabilistic
keywords = python, machine learning, bayes, naive bayes, classifier, probabilistic
classifiers =
Intended Audience :: Science/Research
Intended Audience :: Developers
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
name="wnb",
version=__version__,
description="Python library for the implementations of general and weighted naive Bayes (WNB) classifiers.",
keywords=["python", "bayes", "naive bayes", "classifier", "probabilistic"],
keywords=["python", "machine learning", "bayes", "naive bayes", "classifier", "probabilistic"],
author="Mehdi Samsami",
author_email="mehdisamsami@live.com",
license="BSD License",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.utils._testing import assert_array_almost_equal

from wnb import Distribution as D
from wnb.dist import (
from wnb.stats import (
AllDistributions,
BernoulliDist,
BetaDist,
Expand Down
10 changes: 5 additions & 5 deletions wnb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Python library for the implementations of general and weighted naive Bayes (WNB) classifiers.
"""

__version__ = "0.2.7"
__version__ = "0.3.0"
__author__ = "Mehdi Samsami"


from .base import ContinuousDistMixin, DiscreteDistMixin
from .enums import Distribution
from .gnb import GeneralNB
from .gwnb import GaussianWNB
from wnb.gnb import GeneralNB
from wnb.gwnb import GaussianWNB
from wnb.stats.base import ContinuousDistMixin, DiscreteDistMixin
from wnb.stats.enums import Distribution

__all__ = [
"GeneralNB",
Expand Down
24 changes: 0 additions & 24 deletions wnb/_typing.py

This file was deleted.

298 changes: 9 additions & 289 deletions wnb/dist.py
Original file line number Diff line number Diff line change
@@ -1,291 +1,11 @@
from typing import Any, Mapping
import warnings

import numpy as np
from scipy.special import beta, gamma
from scipy.stats import chi2
from wnb.stats import *

from .base import ContinuousDistMixin, DiscreteDistMixin
from .enums import Distribution as D

# Names of the distribution classes exported by this module.
# NOTE: the `AllDistributions` mapping further down looks each of these names up
# in `globals()`, so every entry must be a class defined in this module.
__all__ = [
"NormalDist",
"LognormalDist",
"ExponentialDist",
"UniformDist",
"ParetoDist",
"GammaDist",
"BetaDist",
"ChiSquaredDist",
"TDist",
"RayleighDist",
"BernoulliDist",
"CategoricalDist",
"GeometricDist",
"PoissonDist",
]


class NormalDist(ContinuousDistMixin):
    """Normal distribution described by a mean `mu` and a standard deviation `sigma`."""

    name = D.NORMAL
    _support = (-np.inf, np.inf)

    def __init__(self, mu: float, sigma: float):
        self.mu = mu
        self.sigma = sigma
        super().__init__()

    @classmethod
    def from_data(cls, data: np.ndarray, **kwargs):
        """Build an instance using the average and standard deviation of `data`."""
        mean = np.average(data)
        spread = np.std(data)
        return cls(mu=mean, sigma=spread)

    def pdf(self, x: float) -> float:
        """Return the density at `x`."""
        # Normalising constant first, then the standardised distance from the mean.
        norm_const = 1.0 / np.sqrt(2 * np.pi * self.sigma**2)
        z = (x - self.mu) / self.sigma
        return norm_const * np.exp(-0.5 * (z**2))


class LognormalDist(ContinuousDistMixin):
    """Log-normal distribution with log-scale parameters `mu` and `sigma`."""

    name = D.LOGNORMAL
    _support = (0, np.inf)

    def __init__(self, mu: float, sigma: float):
        self.mu = mu
        self.sigma = sigma
        super().__init__()

    @classmethod
    def from_data(cls, data: np.ndarray, **kwargs):
        """Build an instance from the average and standard deviation of `log(data)`."""
        logs = np.log(data)
        mean = np.average(logs)
        spread = np.std(logs)
        return cls(mu=mean, sigma=spread)

    def pdf(self, x: float) -> float:
        """Return the density at `x`."""
        scale = 1.0 / (x * self.sigma * np.sqrt(2 * np.pi))
        z = (np.log(x) - self.mu) / self.sigma
        return scale * np.exp(-0.5 * z**2)


class ExponentialDist(ContinuousDistMixin):
    """Exponential distribution with rate parameter `rate`."""

    name = D.EXPONENTIAL
    _support = (0, np.inf)

    def __init__(self, rate: float):
        self.rate = rate
        super().__init__()

    @classmethod
    def from_data(cls, data: np.ndarray, **kwargs):
        """Build an instance with `rate = (len(data) - 2) / sum(data)`."""
        n_samples = len(data)
        total = np.sum(data)
        return cls(rate=(n_samples - 2) / total)

    def pdf(self, x: float) -> float:
        """Return the density at `x`; 0.0 unless `x >= 0`."""
        if x >= 0:
            return self.rate * np.exp(-self.rate * x)
        return 0.0


class UniformDist(ContinuousDistMixin):
    """Continuous uniform distribution on the closed interval `[a, b]`."""

    name = D.UNIFORM
    _support = None

    def __init__(self, a: float, b: float):
        self.a, self.b = a, b
        # The support depends on the instance parameters, so it is set per instance.
        self._support = (a, b)
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance spanning the minimum and maximum of `data`."""
        lowest = np.min(data)
        highest = np.max(data)
        return cls(a=lowest, b=highest)

    def pdf(self, x: float) -> float:
        """Return `1 / (b - a)` when `a <= x <= b`, otherwise 0.0."""
        if self.a <= x <= self.b:
            return 1 / (self.b - self.a)
        return 0.0


class ParetoDist(ContinuousDistMixin):
    """Pareto distribution with scale `x_m` and shape `alpha`."""

    name = D.PARETO
    _support = None

    def __init__(self, x_m: float, alpha: float):
        self.x_m = x_m
        self.alpha = alpha
        # The support starts at the scale parameter, so it is set per instance.
        self._support = (self.x_m, np.inf)
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `x_m = min(data)` and `alpha = n / sum(log(data / x_m))`."""
        x_m = np.min(data)
        log_ratio_sum = np.sum(np.log(data / x_m))
        return cls(x_m=x_m, alpha=len(data) / log_ratio_sum)

    def pdf(self, x: float) -> float:
        """Return the density at `x`; 0.0 unless `x >= x_m`."""
        if x >= self.x_m:
            numerator = self.alpha * self.x_m**self.alpha
            return numerator / x ** (self.alpha + 1)
        return 0.0


class GammaDist(ContinuousDistMixin):
    """Gamma distribution with shape `k` and scale `theta`."""

    name = D.GAMMA
    _support = (0, np.inf)

    def __init__(self, k: float, theta: float):
        self.k = k
        self.theta = theta
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance from sums over `data`, `log(data)` and `data * log(data)`."""
        n = len(data)
        log_data = np.log(data)
        # Term shared by both estimates: n * sum(x * ln x) - sum(x) * sum(ln x).
        shared = n * np.sum(data * log_data) - np.sum(data * np.sum(log_data))
        shape = n * np.sum(data) / shared
        scale = shared / n**2
        return cls(k=shape, theta=scale)

    def pdf(self, x: float) -> float:
        """Return the density at `x`."""
        numerator = x ** (self.k - 1) * np.exp(-x / self.theta)
        denominator = gamma(self.k) * self.theta**self.k
        return numerator / denominator


class BetaDist(ContinuousDistMixin):
    """Beta distribution with shape parameters `alpha` and `beta`."""

    name = D.BETA
    _support = (0, 1)

    def __init__(self, alpha: float, beta: float):
        self.alpha = alpha
        self.beta = beta
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance from the average and the sample variance (ddof=1) of `data`."""
        mean = np.average(data)
        variance = np.var(data, ddof=1)
        common = (mean * (1 - mean) / variance) - 1
        return cls(alpha=mean * common, beta=(1 - mean) * common)

    def pdf(self, x: float) -> float:
        """Return the density at `x`."""
        numerator = (x ** (self.alpha - 1)) * (1 - x) ** (self.beta - 1)
        # `beta` here is the module-level beta function, not the instance attribute.
        return numerator / beta(self.alpha, self.beta)


class ChiSquaredDist(ContinuousDistMixin):
    """Chi-squared distribution with `k` degrees of freedom."""

    name = D.CHI_SQUARED
    _support = (0, np.inf)

    def __init__(self, k: int):
        self.k = k
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `k` set to the rounded average of `data`."""
        mean = np.average(data)
        return cls(k=round(mean))

    def pdf(self, x: float) -> float:
        """Return the density at `x`, delegating to `scipy.stats.chi2.pdf`."""
        density = chi2.pdf(x, self.k)
        return density


class TDist(ContinuousDistMixin):
    """Student's t distribution with `df` degrees of freedom."""

    name = D.T
    _support = (-np.inf, np.inf)

    def __init__(self, df: float):
        self.df = df
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `df = len(data) - 1`."""
        n_samples = len(data)
        return cls(df=n_samples - 1)

    def pdf(self, x: float) -> float:
        """Return the density at `x`."""
        coefficient = gamma((self.df + 1) / 2) / (np.sqrt(self.df * np.pi) * gamma(self.df / 2))
        base = 1 + (x**2 / self.df)
        exponent = -(self.df + 1) / 2
        return coefficient * base**exponent


class RayleighDist(ContinuousDistMixin):
    """Rayleigh distribution with scale parameter `sigma`."""

    name = D.RAYLEIGH
    _support = (0, np.inf)

    def __init__(self, sigma: float):
        self.sigma = sigma
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `sigma = sqrt(mean(data**2) / 2)`."""
        mean_square = np.mean(data**2)
        return cls(sigma=np.sqrt(mean_square / 2))

    def pdf(self, x: float) -> float:
        """Return the density at `x`; 0.0 unless `x >= 0`."""
        if x >= 0:
            sigma_sq = self.sigma**2
            return (x / sigma_sq) * np.exp(-(x**2) / (2 * sigma_sq))
        return 0.0


class BernoulliDist(DiscreteDistMixin):
    """Bernoulli distribution over {0, 1} with success probability `p`."""

    name = D.BERNOULLI
    _support = [0, 1]

    def __init__(self, p: float):
        self.p = p
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `p = (count of ones + alpha) / len(data)`.

        `alpha` is read from `kwargs` and defaults to 1e-10.
        """
        alpha = kwargs.get("alpha", 1e-10)
        ones = (np.array(data) == 1).sum()
        return cls(p=(ones + alpha) / len(data))

    def pmf(self, x: int) -> float:
        """Return `p` for 1, `1 - p` for 0, and 0.0 for anything outside the support."""
        if x not in self._support:
            return 0.0
        if x == 1:
            return self.p
        return 1 - self.p


class CategoricalDist(DiscreteDistMixin):
    """Categorical distribution given by a mapping from value to probability."""

    name = D.CATEGORICAL
    _support = None

    def __init__(self, prob: Mapping[Any, float]):
        self.prob = prob
        # The support is whatever values the mapping knows about.
        self._support = list(self.prob.keys())
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `(count + alpha) / len(data)` for each unique value.

        `alpha` is read from `kwargs` and defaults to 1e-10.
        """
        alpha = kwargs.get("alpha", 1e-10)
        values, counts = np.unique(data, return_counts=True)
        table = {}
        for value, count in zip(values, counts):
            table[value] = (count + alpha) / len(data)
        return cls(prob=table)

    def pmf(self, x: Any) -> float:
        """Return the stored probability of `x`, or 0.0 if `x` is not in the mapping."""
        return self.prob.get(x, 0.0)


class GeometricDist(DiscreteDistMixin):
    """Geometric distribution on 1, 2, 3, ... with success probability `p`."""

    name = D.GEOMETRIC
    _support = (1, np.inf)

    def __init__(self, p: float):
        self.p = p
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `p = len(data) / sum(data)`."""
        n_samples = len(data)
        return cls(p=n_samples / np.sum(data))

    def pmf(self, x: int) -> float:
        """Return `p * (1 - p) ** (x - 1)` for integral `x` at or above the lower bound, else 0.0."""
        in_support = x >= self._support[0] and x - int(x) == 0
        if in_support:
            return self.p * (1 - self.p) ** (x - 1)
        return 0.0


class PoissonDist(DiscreteDistMixin):
    """Poisson distribution on 0, 1, 2, ... with rate parameter `rate`."""

    name = D.POISSON
    _support = (0, np.inf)

    def __init__(self, rate: float):
        self.rate = rate
        super().__init__()

    @classmethod
    def from_data(cls, data, **kwargs):
        """Build an instance with `rate = sum(data) / len(data)`."""
        return cls(rate=np.sum(data) / len(data))

    def pmf(self, x: int) -> float:
        """Return `exp(-rate) * rate**x / x!` for integral `x >= 0`, otherwise 0.0."""
        # Use the standard-library factorial: the `np.math` alias was deprecated in
        # NumPy 1.25 and removed in NumPy 2.0, so `np.math.factorial` raises
        # AttributeError there. Imported locally to keep this block self-contained.
        from math import factorial

        if x >= self._support[0] and x - int(x) == 0:
            # `x` may arrive as an integral float (e.g. 3.0); factorial needs an
            # int, and recent Python versions reject floats.
            return (np.exp(-self.rate) * self.rate**x) / factorial(int(x))
        return 0.0


# Registry mapping each distribution's `name` to its class, resolved from the names in `__all__`.
AllDistributions = {dist.name: dist for dist in map(globals().__getitem__, __all__)}


# Distribution kinds listed as non-numeric (currently only the categorical one).
NonNumericDistributions = [D.CATEGORICAL]
# Emit a deprecation notice for `wnb.dist`, pointing users at `wnb` / `wnb.stats`.
# stacklevel=2 attributes the warning to the caller's frame instead of this line.
warnings.warn(
"The `wnb.dist` module is deprecated and will be removed in a future release. "
"Please update your imports to use `wnb` or `wnb.stats` directly. "
"Using `wnb.dist` will continue to work in this version, but it may be removed in future versions.",
DeprecationWarning,
stacklevel=2,
)
Loading

0 comments on commit f2e6657

Please sign in to comment.