Commit a041ca77 authored by Julien Jerphanion's avatar Julien Jerphanion

Add python front-end

Raw source code taken from:
https://github.com/scikit-learn/scikit-learn/blob/579e7de7f38f9f514ff2b2be049e67b14e723d17/sklearn/cluster/_kmeans.py#L525

Not functional; included only for the sake of reference.
parent ad0c7e81
def _kmeans_single_lloyd(X, sample_weight, centers_init, max_iter=300,
                         verbose=False, x_squared_norms=None, tol=1e-4,
                         n_threads=1):
    """A single run of k-means lloyd, assumes preparation completed prior.

    Parameters
    ----------
    X : {ndarray, sparse matrix} of shape (n_samples, n_features)
        The observations to cluster. If sparse matrix, must be in CSR format.

    sample_weight : ndarray of shape (n_samples,)
        The weights for each observation in X.

    centers_init : ndarray of shape (n_clusters, n_features)
        The initial centers. The array is used directly as one of the two
        working buffers (no copy is taken here), so callers should presumably
        not rely on its content after the call -- TODO confirm against the
        ``lloyd_iter_chunked_*`` implementations.

    max_iter : int, default=300
        Maximum number of iterations of the k-means algorithm to run.

    verbose : bool, default=False
        Verbosity mode. When enabled, the inertia is computed and printed at
        every iteration.

    x_squared_norms : ndarray of shape (n_samples,), default=None
        Precomputed x_squared_norms.

    tol : float, default=1e-4
        Relative tolerance with regards to Frobenius norm of the difference
        in the cluster centers of two consecutive iterations to declare
        convergence. Inside this function it is compared directly against
        the sum of squared center shifts.
        It's not advised to set `tol=0` since convergence might never be
        declared due to rounding errors. Use a very small number instead.

    n_threads : int, default=1
        The number of OpenMP threads to use for the computation. Parallelism is
        sample-wise on the main cython loop which assigns each sample to its
        closest center.

    Returns
    -------
    label : ndarray of shape (n_samples,)
        label[i] is the code or index of the centroid the
        i'th observation is closest to.

    inertia : float
        The final value of the inertia criterion (sum of squared distances to
        the closest centroid for all observations in the training set).

    centroid : ndarray of shape (n_clusters, n_features)
        Centroids found at the last iteration of k-means.

    n_iter : int
        Number of iterations run (0 when ``max_iter`` is 0).
    """
    n_clusters = centers_init.shape[0]

    # Buffers to avoid new allocations at each iteration.
    centers = centers_init
    centers_new = np.zeros_like(centers)
    labels = np.full(X.shape[0], -1, dtype=np.int32)
    labels_old = labels.copy()
    weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype)
    center_shift = np.zeros(n_clusters, dtype=X.dtype)

    # Pick the iteration / inertia kernels matching the storage of X.
    if sp.issparse(X):
        lloyd_iter = lloyd_iter_chunked_sparse
        _inertia = _inertia_sparse
    else:
        lloyd_iter = lloyd_iter_chunked_dense
        _inertia = _inertia_dense

    strict_convergence = False

    # Tracked explicitly so the return value is defined even when the loop
    # body never executes (max_iter == 0).
    n_iter = 0

    # Threadpoolctl context to limit the number of threads in second level of
    # nested parallelism (i.e. BLAS) to avoid oversubscription.
    with threadpool_limits(limits=1, user_api="blas"):
        for i in range(max_iter):
            n_iter = i + 1

            lloyd_iter(X, sample_weight, x_squared_norms, centers, centers_new,
                       weight_in_clusters, labels, center_shift, n_threads)

            if verbose:
                inertia = _inertia(X, sample_weight, centers, labels)
                print(f"Iteration {i}, inertia {inertia}.")

            # The freshly computed centers become the current ones; the old
            # buffer is recycled as the output buffer of the next iteration.
            centers, centers_new = centers_new, centers

            if np.array_equal(labels, labels_old):
                # First check the labels for strict convergence.
                if verbose:
                    print(f"Converged at iteration {i}: strict convergence.")
                strict_convergence = True
                break

            # No strict convergence, check for tol based convergence.
            center_shift_tot = (center_shift**2).sum()
            if center_shift_tot <= tol:
                if verbose:
                    print(f"Converged at iteration {i}: center shift "
                          f"{center_shift_tot} within tolerance {tol}.")
                break

            labels_old[:] = labels

        if not strict_convergence:
            # rerun E-step so that predicted labels match cluster centers
            lloyd_iter(X, sample_weight, x_squared_norms, centers, centers,
                       weight_in_clusters, labels, center_shift, n_threads,
                       update_centers=False)

    inertia = _inertia(X, sample_weight, centers, labels)

    return labels, inertia, centers, n_iter
\ No newline at end of file
...@@ -8,7 +8,7 @@ from Cython.Build import build_ext ...@@ -8,7 +8,7 @@ from Cython.Build import build_ext
# #
extensions = [ extensions = [
Extension("kmeans", Extension("_kmeans",
sources=["_kmeans.pyx"], sources=["_kmeans.pyx"],
include_dirs=[numpy.get_include()], include_dirs=[numpy.get_include()],
define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")], define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment