Source code for pyspark.sql.utils

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import functools
import os
from typing import (
    Any,
    Callable,
    Optional,
    List,
    Sequence,
    TYPE_CHECKING,
    cast,
    TypeVar,
    Union,
    Type,
)

# For backward compatibility.
from pyspark.errors import (  # noqa: F401
    AnalysisException,
    ParseException,
    IllegalArgumentException,
    StreamingQueryException,
    QueryExecutionException,
    PythonException,
    UnknownException,
    SparkUpgradeException,
    PySparkNotImplementedError,
    PySparkRuntimeError,
)
from pyspark.util import is_remote_only, JVM_INT_MAX
from pyspark.errors.exceptions.captured import CapturedException  # noqa: F401
from pyspark.find_spark_home import _find_spark_home

if TYPE_CHECKING:
    from py4j.java_collections import JavaArray
    from py4j.java_gateway import (
        JavaClass,
        JavaGateway,
        JavaObject,
    )
    from pyspark import SparkContext
    from pyspark.sql.session import SparkSession
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.column import Column
    from pyspark.sql.window import Window
    from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex

# Record whether NumPy can be imported. When it can, ``np`` stays bound at
# module level (hence the F401 suppression on the otherwise unused import).
try:
    import numpy as np  # noqa: F401
except ImportError:
    has_numpy: bool = False
else:
    has_numpy = True


# Type variable bound to "any callable"; the decorators in this module use it
# to declare that they return a callable of the same type they were given.
FuncT = TypeVar("FuncT", bound=Callable[..., Any])


def toJArray(gateway: "JavaGateway", jtype: "JavaClass", arr: Sequence[Any]) -> "JavaArray":
    """
    Copy the items of a Python sequence into a newly allocated Java array.

    Parameters
    ----------
    gateway :
        Py4j Gateway used to allocate the array
    jtype :
        java type of the array elements
    arr :
        python sequence whose items are copied in order

    Returns
    -------
    The array returned by ``gateway.new_array``, with one slot per item of ``arr``.
    """
    java_array: "JavaArray" = gateway.new_array(jtype, len(arr))
    for position, item in enumerate(arr):
        java_array[position] = item
    return java_array


def require_test_compiled() -> None:
    """
    Ensure the Spark SQL test classes have been compiled.

    Looks under the Spark home directory for ``sql/core/target/*/test-classes``.

    Raises
    ------
    PySparkRuntimeError
        With error class ``TEST_CLASS_NOT_COMPILED`` when no directory matches
        the pattern.
    """
    import os
    from glob import glob

    pattern = os.path.join(_find_spark_home(), "sql", "core", "target", "*", "test-classes")

    if not glob(pattern):
        raise PySparkRuntimeError(
            error_class="TEST_CLASS_NOT_COMPILED",
            message_parameters={"test_class_path": pattern},
        )


class ForeachBatchFunction:
    """
    Python-side implementation of the Java interface 'ForeachBatchFunction'.

    It holds the user-defined 'foreachBatch' function so that the JVM can invoke
    it through :meth:`call` while the query is active.
    """

    def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]):
        # `session` supplies the SparkContext reused for every batch;
        # `func` is the user callback invoked as func(batch_df, batch_id).
        self.session = session
        self.func = func

    def call(self, jdf: "JavaObject", batch_id: int) -> None:
        """
        Wrap the JVM DataFrame of one batch and hand it to the user function.

        Any exception raised while wrapping or by the user function is stored
        on ``self.error`` and then re-raised.
        """
        from pyspark.sql.dataframe import DataFrame
        from pyspark.sql.session import SparkSession

        try:
            jvm_session = jdf.sparkSession()
            # assuming that spark context is still the same between JVM and PySpark
            batch_session = SparkSession(self.session.sparkContext, jvm_session)
            batch_df = DataFrame(jdf, batch_session)
            self.func(batch_df, batch_id)
        except Exception as exc:
            self.error = exc
            raise exc

    class Java:
        # Java interface name(s) implemented by the enclosing Python class.
        implements = ["org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction"]


# Python implementation of 'org.apache.spark.sql.catalyst.util.StringConcat'
class StringConcat:
    """
    Collects string fragments while keeping the stored text within ``maxLength``.

    ``length`` counts every character offered to :meth:`append`, including the
    ones that were truncated or dropped, and is capped at ``JVM_INT_MAX - 15``.
    """

    def __init__(self, maxLength: int = JVM_INT_MAX - 15):
        self.maxLength: int = maxLength
        self.strings: List[str] = []
        self.length: int = 0

    def atLimit(self) -> bool:
        """Return True once the counted length has reached ``maxLength``."""
        return self.length >= self.maxLength

    def append(self, s: str) -> None:
        """
        Add ``s``, truncating it to the remaining room; ``None`` is ignored.

        Once the limit is reached nothing more is stored, but the counted
        length still grows (up to its cap).
        """
        if s is None:
            return

        incoming = len(s)
        if not self.atLimit():
            room = self.maxLength - self.length
            if room >= incoming:
                self.strings.append(s)
            else:
                self.strings.append(s[0:room])

        self.length = min(self.length + incoming, JVM_INT_MAX - 15)

    def toString(self) -> str:
        """Return the concatenation of all stored fragments."""
        return "".join(self.strings)


# Python implementation of 'org.apache.spark.util.SparkSchemaUtils.escapeMetaCharacters'
def escape_meta_characters(s: str) -> str:
    """
    Replace control characters in ``s`` with their two-character escape form.

    Newline, carriage return, tab, form feed, backspace, vertical tab and bell
    become ``\\n``, ``\\r``, ``\\t``, ``\\f``, ``\\b``, ``\\v`` and ``\\a``
    respectively; every other character is left as is.
    """
    substitutions = (
        ("\n", "\\n"),
        ("\r", "\\r"),
        ("\t", "\\t"),
        ("\f", "\\f"),
        ("\b", "\\b"),
        ("\u000B", "\\v"),
        ("\u0007", "\\a"),
    )
    escaped = s
    for raw, visible in substitutions:
        escaped = escaped.replace(raw, visible)
    return escaped


def to_str(value: Any) -> Optional[str]:
    """
    Convert ``value`` to a string the way ``str()`` does, with two exceptions:
    bool values come out lower-cased ("true" / "false"), and ``None`` is
    returned unchanged instead of becoming the string "None".
    """
    if value is None:
        return None

    text = str(value)
    return text.lower() if isinstance(value, bool) else text


def is_timestamp_ntz_preferred() -> bool:
    """
    Tell whether TimestampNTZType is preferred according to the SQL configuration set.

    Under Spark Connect the active session's ``spark.sql.timestampType`` conf
    is compared with "TIMESTAMP_NTZ"; otherwise the JVM-side
    ``PythonSQLUtils.isTimestampNTZPreferred()`` is consulted. Without an
    active Connect session, or without a JVM, the answer is False.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm = SparkContext._jvm
        if jvm is None:
            return False
        return jvm.PythonSQLUtils.isTimestampNTZPreferred()

    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession

    active_session = ConnectSparkSession.getActiveSession()
    if active_session is None:
        return False
    return active_session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ"


def is_remote() -> bool:
    """
    Returns whether the current running environment is for Spark Connect.

    .. versionadded:: 4.0.0

    Notes
    -----
    This will only return ``True`` if there is a remote session running.
    Otherwise, it returns ``False``.

    This API is unstable, and for developers.

    Returns
    -------
    bool

    Examples
    --------
    >>> from pyspark.sql import is_remote
    >>> is_remote()
    False
    """
    # Either the Connect-mode environment flag is set, or this installation
    # reports itself as remote-only.
    return ("SPARK_CONNECT_MODE_ENABLED" in os.environ) or is_remote_only()
def _should_dispatch_to_connect() -> bool:
    """
    Return whether a decorated API should be forwarded to its Spark Connect counterpart.

    Forwarding happens when :func:`is_remote` is true and the environment
    variable ``PYSPARK_NO_NAMESPACE_SHARE`` is not set.
    """
    return is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ


def try_remote_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When dispatching to Connect, the call is forwarded to the function of the
    same name in ``pyspark.sql.connect.functions``; otherwise ``f`` itself runs.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            # Imported lazily so the Connect package is only loaded when used.
            from pyspark.sql.connect import functions

            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_partitioning_remote_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When dispatching to Connect, the call is forwarded to the function of the
    same name in ``pyspark.sql.connect.functions.partitioning``.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.functions import partitioning

            return getattr(partitioning, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_remote_avro_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When dispatching to Connect, the call is forwarded to the function of the
    same name in ``pyspark.sql.connect.avro.functions``.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.avro import functions

            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_remote_protobuf_functions(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When dispatching to Connect, the call is forwarded to the function of the
    same name in ``pyspark.sql.connect.protobuf.functions``.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.protobuf import functions

            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_remote_window(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When dispatching to Connect, the call is forwarded to the attribute of the
    same name on ``pyspark.sql.connect.window.Window``.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.window import Window

            return getattr(Window, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def try_remote_windowspec(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    When dispatching to Connect, the call is forwarded to the attribute of the
    same name on ``pyspark.sql.connect.window.WindowSpec``.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.window import WindowSpec

            return getattr(WindowSpec, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def get_active_spark_context() -> "SparkContext":
    """
    Return the active SparkContext.

    Raises
    ------
    PySparkRuntimeError
        With error class ``SESSION_OR_CONTEXT_NOT_EXISTS`` when there is no
        active SparkContext or its ``_jvm`` is ``None``.
    """
    from pyspark import SparkContext

    sc = SparkContext._active_spark_context
    if sc is None or sc._jvm is None:
        raise PySparkRuntimeError(
            error_class="SESSION_OR_CONTEXT_NOT_EXISTS",
            message_parameters={},
        )
    return sc


def try_remote_session_classmethod(f: FuncT) -> FuncT:
    """
    Mark API supported from Spark Connect.

    Intended for classmethods: when dispatching to Connect, the first
    positional argument (the class) is dropped and the call is forwarded to the
    attribute of the same name on ``pyspark.sql.connect.session.SparkSession``.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.session import SparkSession

            assert inspect.isclass(args[0])
            # args[0] is the classic class; the Connect classmethod binds its own.
            return getattr(SparkSession, f.__name__)(*args[1:], **kwargs)
        else:
            return f(*args, **kwargs)

    return cast(FuncT, wrapped)


def dispatch_df_method(f: FuncT) -> FuncT:
    """
    For the usecases of direct DataFrame.union(df, ...), it checks if self
    is a Connect DataFrame or Classic DataFrame, and dispatches.

    Raises
    ------
    PySparkNotImplementedError
        With error class ``NOT_IMPLEMENTED`` when the first argument is not an
        instance of the DataFrame class for the current mode.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

            if isinstance(args[0], ConnectDataFrame):
                return getattr(ConnectDataFrame, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame

            if isinstance(args[0], ClassicDataFrame):
                return getattr(ClassicDataFrame, f.__name__)(*args, **kwargs)

        # Reached only when args[0] did not match the class for the current mode.
        raise PySparkNotImplementedError(
            error_class="NOT_IMPLEMENTED",
            message_parameters={"feature": f"DataFrame.{f.__name__}"},
        )

    return cast(FuncT, wrapped)


def dispatch_col_method(f: FuncT) -> FuncT:
    """
    For the usecases of direct Column.method(col, ...), it checks if self
    is a Connect Column or Classic Column, and dispatches.

    Raises
    ------
    PySparkNotImplementedError
        With error class ``NOT_IMPLEMENTED`` when the first argument is not an
        instance of the Column class for the current mode.
    """

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if _should_dispatch_to_connect():
            from pyspark.sql.connect.column import Column as ConnectColumn

            if isinstance(args[0], ConnectColumn):
                return getattr(ConnectColumn, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.column import Column as ClassicColumn

            if isinstance(args[0], ClassicColumn):
                return getattr(ClassicColumn, f.__name__)(*args, **kwargs)

        # Reached only when args[0] did not match the class for the current mode.
        # The feature is a Column method, so it is reported as such.
        raise PySparkNotImplementedError(
            error_class="NOT_IMPLEMENTED",
            message_parameters={"feature": f"Column.{f.__name__}"},
        )

    return cast(FuncT, wrapped)


def pyspark_column_op(
    func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None
) -> Union["SeriesOrIndex", None]:
    """
    Wrapper function for column_op to get proper Column class.

    Looks up ``func_name`` on the Connect or classic ``Column`` class (chosen by
    :func:`is_remote`), applies it to ``left`` and ``right`` through
    ``column_op``, and optionally fills nulls in the result with ``fillna``.
    """
    from pyspark.pandas.base import column_op
    from pyspark.sql.column import Column as PySparkColumn
    from pyspark.pandas.data_type_ops.base import _is_extension_dtypes

    if is_remote():
        from pyspark.sql.connect.column import Column as ConnectColumn

        Column = ConnectColumn
    else:
        Column = PySparkColumn  # type: ignore[assignment]
    result = column_op(getattr(Column, func_name))(left, right)
    # It works as expected on extension dtype, so we don't need to call `fillna` for this case.
    if (fillna is not None) and (_is_extension_dtypes(left) or _is_extension_dtypes(right)):
        fillna = None
    # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
    return result.fillna(fillna) if fillna is not None else result


def get_column_class() -> Type["Column"]:
    """Return the Connect ``Column`` class when :func:`is_remote` is true, else the regular one."""
    from pyspark.sql.column import Column as PySparkColumn

    if is_remote():
        from pyspark.sql.connect.column import Column as ConnectColumn

        return ConnectColumn
    else:
        return PySparkColumn


def get_dataframe_class() -> Type["DataFrame"]:
    """Return the Connect ``DataFrame`` class when :func:`is_remote` is true, else the regular one."""
    from pyspark.sql.dataframe import DataFrame as PySparkDataFrame

    if is_remote():
        from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

        return ConnectDataFrame
    else:
        return PySparkDataFrame


def get_window_class() -> Type["Window"]:
    """Return the Connect ``Window`` class when :func:`is_remote` is true, else the regular one."""
    from pyspark.sql.window import Window as PySparkWindow

    if is_remote():
        from pyspark.sql.connect.window import Window as ConnectWindow

        return ConnectWindow  # type: ignore[return-value]
    else:
        return PySparkWindow


def get_lit_sql_str(val: str) -> str:
    """
    Return ``val`` as a single-quoted SQL string literal, with backslashes and
    single quotes backslash-escaped.
    """
    # Equivalent to `lit(val)._jc.expr().sql()` for string typed val
    # See `sql` definition in `sql/catalyst/src/main/scala/org/apache/spark/
    # sql/catalyst/expressions/literals.scala`
    return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"