# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial
try:
from sklearn.exceptions import DataConversionWarning
except ImportError: # pragma: no cover
DataConversionWarning = None
from ....tensor.spatial import distance
from ...utils.validation import check_non_negative
from .core import PairwiseDistances
from .manhattan import manhattan_distances
from .cosine import cosine_distances
from .euclidean import euclidean_distances
from .haversine import haversine_distances
_VALID_METRICS = [
"euclidean",
"l2",
"l1",
"manhattan",
"cityblock",
"braycurtis",
"canberra",
"chebyshev",
"correlation",
"cosine",
"dice",
"hamming",
"jaccard",
"kulsinski",
"mahalanobis",
"matching",
"minkowski",
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
"wminkowski",
"haversine",
]
# Helper functions - distance
PAIRWISE_DISTANCE_FUNCTIONS = {
# If updating this dictionary, update the doc in both distance_metrics()
# and also in pairwise_distances()!
"cityblock": manhattan_distances,
"cosine": cosine_distances,
"euclidean": euclidean_distances,
"haversine": haversine_distances,
"l2": euclidean_distances,
"l1": manhattan_distances,
"manhattan": manhattan_distances,
"precomputed": None, # HACK: precomputed is always allowed, never called
}
# These distances recquire boolean tensors, when using mars.tensor.spatial.distance
PAIRWISE_BOOLEAN_FUNCTIONS = [
"dice",
"jaccard",
"kulsinski",
"matching",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
"yule",
]
[docs]def pairwise_distances(X, Y=None, metric="euclidean", **kwds):
if (
metric not in _VALID_METRICS
and not callable(metric)
and metric != "precomputed"
):
raise ValueError(
f"Unknown metric {metric}. Valid metrics are {_VALID_METRICS}, "
"or 'precomputed', or a callable"
)
if metric == "precomputed":
X, _ = PairwiseDistances.check_pairwise_arrays(X, Y, precomputed=True)
whom = (
"`pairwise_distances`. Precomputed distance "
" need to have non-negative values."
)
X = check_non_negative(X, whom=whom)
return X
elif metric in PAIRWISE_DISTANCE_FUNCTIONS:
func = PAIRWISE_DISTANCE_FUNCTIONS[metric]
else:
# including when metric is callable
dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else None
if (
dtype == bool
and (X.dtype != bool or (Y is not None and Y.dtype != bool))
and DataConversionWarning is not None
):
msg = f"Data was converted to boolean for metric {metric}"
warnings.warn(msg, DataConversionWarning)
X, Y = PairwiseDistances.check_pairwise_arrays(X, Y, dtype=dtype)
if X is Y:
return distance.squareform(distance.pdist(X, metric=metric, **kwds))
func = partial(distance.cdist, metric=metric, **kwds)
return func(X, Y, **kwds)