Skip to content

Commit 3ded533

Browse files
committed
feat: add Radial Basis Function Neural Network implementation
Implements #12322 - Add RadialBasisFunctionNetwork class with train() and predict() methods - Use KMeans clustering for RBF center initialization - Implement Gaussian RBF activation functions - Use least-squares fitting for output weight calculation - Include comprehensive doctests and error handling - Add detailed docstrings with mathematical formulas
1 parent e2a78d4 commit 3ded533

File tree

1 file changed

+178
-0
lines changed

1 file changed

+178
-0
lines changed

machine_learning/rbf_network.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""
2+
Radial Basis Function Neural Network (RBFNN)
3+
4+
A neural network that uses radial basis functions (typically Gaussian) as activation
5+
functions in the hidden layer. RBFNNs are effective for function approximation and
6+
classification tasks.
7+
8+
Architecture:
9+
- Input Layer: Accepts n-dimensional input vectors
10+
- Hidden Layer: RBF neurons (Gaussian functions centered at data points)
11+
- Output Layer: Linear combination of hidden layer outputs
12+
13+
Reference: https://en.wikipedia.org/wiki/Radial_basis_function_network
14+
"""
15+
16+
import numpy as np
17+
from sklearn.cluster import KMeans
18+
19+
20+
class RadialBasisFunctionNetwork:
21+
"""
22+
Radial Basis Function Neural Network for regression and classification.
23+
24+
Uses KMeans clustering to determine RBF centers and least-squares
25+
fitting for output weights.
26+
27+
Attributes:
28+
num_centers: Number of RBF centers (hidden neurons)
29+
gamma: Spread parameter for Gaussian RBF (inverse of variance)
30+
centers: Cluster centers from KMeans
31+
weights: Output layer weights
32+
"""
33+
34+
def __init__(self, num_centers: int = 10, gamma: float = 1.0):
35+
"""
36+
Initialize RBFNN with specified parameters.
37+
38+
Args:
39+
num_centers: Number of RBF centers (default: 10)
40+
gamma: Gaussian spread parameter (default: 1.0)
41+
42+
>>> rbfnn = RadialBasisFunctionNetwork(num_centers=5, gamma=2.0)
43+
>>> rbfnn.num_centers
44+
5
45+
>>> rbfnn.gamma
46+
2.0
47+
"""
48+
if num_centers <= 0:
49+
raise ValueError("num_centers must be positive")
50+
if gamma <= 0:
51+
raise ValueError("gamma must be positive")
52+
53+
self.num_centers = num_centers
54+
self.gamma = gamma
55+
self.centers = None
56+
self.weights = None
57+
58+
def _gaussian_rbf(self, x: np.ndarray, center: np.ndarray) -> float:
59+
"""
60+
Compute Gaussian radial basis function.
61+
62+
RBF(x) = exp(-gamma * ||x - center||^2)
63+
64+
Args:
65+
x: Input vector
66+
center: RBF center vector
67+
68+
Returns:
69+
Activation value between 0 and 1
70+
"""
71+
distance_squared = np.sum((x - center) ** 2)
72+
return np.exp(-self.gamma * distance_squared)
73+
74+
def _compute_rbf_activations(self, X: np.ndarray) -> np.ndarray:
75+
"""
76+
Compute RBF activations for all input samples.
77+
78+
Args:
79+
X: Input data matrix (n_samples, n_features)
80+
81+
Returns:
82+
Activation matrix (n_samples, num_centers)
83+
"""
84+
n_samples = X.shape[0]
85+
activations = np.zeros((n_samples, self.num_centers))
86+
87+
for i in range(n_samples):
88+
for j in range(self.num_centers):
89+
activations[i, j] = self._gaussian_rbf(X[i], self.centers[j])
90+
91+
return activations
92+
93+
def train(self, X: np.ndarray, y: np.ndarray) -> None:
94+
"""
95+
Train the RBFNN using KMeans clustering and least-squares fitting.
96+
97+
Steps:
98+
1. Find RBF centers using KMeans clustering
99+
2. Compute RBF activations for all training samples
100+
3. Calculate output weights using least-squares fitting
101+
102+
Args:
103+
X: Training data (n_samples, n_features)
104+
y: Target values (n_samples,) or (n_samples, n_outputs)
105+
106+
>>> import numpy as np
107+
>>> np.random.seed(42)
108+
>>> X_train = np.random.randn(50, 2)
109+
>>> y_train = np.sum(X_train ** 2, axis=1)
110+
>>> rbfnn = RadialBasisFunctionNetwork(num_centers=5, gamma=1.0)
111+
>>> rbfnn.train(X_train, y_train)
112+
>>> rbfnn.centers.shape
113+
(5, 2)
114+
>>> rbfnn.weights.shape
115+
(5,)
116+
"""
117+
if X.shape[0] != len(y):
118+
raise ValueError("X and y must have the same number of samples")
119+
120+
if self.num_centers > X.shape[0]:
121+
raise ValueError("num_centers cannot exceed number of training samples")
122+
123+
# Step 1: Find RBF centers using KMeans clustering
124+
kmeans = KMeans(n_clusters=self.num_centers, random_state=42, n_init=10)
125+
kmeans.fit(X)
126+
self.centers = kmeans.cluster_centers_
127+
128+
# Step 2: Compute RBF activations
129+
activations = self._compute_rbf_activations(X)
130+
131+
# Step 3: Solve for output weights using least-squares
132+
# weights = (A^T A)^-1 A^T y, where A is the activation matrix
133+
self.weights = np.linalg.lstsq(activations, y, rcond=None)[0]
134+
135+
def predict(self, X: np.ndarray) -> np.ndarray:
136+
"""
137+
Make predictions using trained RBFNN.
138+
139+
Args:
140+
X: Input data (n_samples, n_features)
141+
142+
Returns:
143+
Predictions (n_samples,) or (n_samples, n_outputs)
144+
145+
>>> import numpy as np
146+
>>> np.random.seed(42)
147+
>>> X_train = np.array([[0, 0], [1, 1], [2, 2]])
148+
>>> y_train = np.array([0, 2, 4])
149+
>>> rbfnn = RadialBasisFunctionNetwork(num_centers=2, gamma=1.0)
150+
>>> rbfnn.train(X_train, y_train)
151+
>>> X_test = np.array([[0.5, 0.5], [1.5, 1.5]])
152+
>>> predictions = rbfnn.predict(X_test)
153+
>>> predictions.shape
154+
(2,)
155+
"""
156+
if self.centers is None or self.weights is None:
157+
raise RuntimeError("Model must be trained before making predictions")
158+
159+
if X.shape[1] != self.centers.shape[1]:
160+
msg = (
161+
f"Input dimension {X.shape[1]} does not match "
162+
f"training dimension {self.centers.shape[1]}"
163+
)
164+
raise ValueError(msg)
165+
166+
# Compute RBF activations for test data
167+
activations = self._compute_rbf_activations(X)
168+
169+
# Compute predictions as linear combination of activations
170+
predictions = activations @ self.weights
171+
172+
return predictions
173+
174+
175+
if __name__ == "__main__":
176+
import doctest
177+
178+
doctest.testmod()

0 commit comments

Comments
 (0)