#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import string
from typing import Any, Dict, Optional, Union, List, Sequence, Mapping, Tuple
import uuid
import warnings
import pandas as pd
from pyspark.pandas.internal import InternalFrame
from pyspark.pandas.namespace import _get_index_map
from pyspark import pandas as ps
from pyspark.sql import SparkSession
from pyspark.pandas.utils import default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series
__all__ = ["sql"]
# This is not used in this file. It's for legacy sql_processor.
_CAPTURE_SCOPES = 3
def sql(
query: str,
index_col: Optional[Union[str, List[str]]] = None,
args: Optional[Union[Dict[str, Any], List]] = None,
**kwargs: Any,
) -> DataFrame:
"""
Execute a SQL query and return the result as a pandas-on-Spark DataFrame.
This function acts as a standard Python string formatter that understands
the following variable types:
* pandas-on-Spark DataFrame
* pandas-on-Spark Series
* pandas DataFrame
* pandas Series
* string
The method can also bind named parameters to SQL literals from `args`.
Parameters
----------
query : str
the SQL query
index_col : str or list of str, optional
Column names to be used in Spark to represent pandas-on-Spark's index. The index name
in pandas-on-Spark is ignored. By default, the index is always lost.
.. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`,
and pass it to the SQL statement with the `index_col` parameter.
For example,
>>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c'])
>>> new_psdf = psdf.reset_index()
>>> ps.sql("SELECT * FROM {new_psdf}", index_col="index", new_psdf=new_psdf)
... # doctest: +NORMALIZE_WHITESPACE
       A  B
index
a      1  4
b      2  5
c      3  6
For MultiIndex,
>>> psdf = ps.DataFrame(
... {"A": [1, 2, 3], "B": [4, 5, 6]},
... index=pd.MultiIndex.from_tuples(
... [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
... ),
... )
>>> new_psdf = psdf.reset_index()
>>> ps.sql(
... "SELECT * FROM {new_psdf}", index_col=["index1", "index2"], new_psdf=new_psdf)
... # doctest: +NORMALIZE_WHITESPACE
               A  B
index1 index2
a      b       1  4
c      d       2  5
e      f       3  6
Also note that the index name(s) should match the existing column name(s).
args : dict or list
A dictionary of parameter names to Python objects or a list of Python objects
that can be converted to SQL literal expressions. See
<a href="https://spark.apache.org/docs/latest/sql-ref-datatypes.html">
Supported Data Types</a> for supported value types in Python.
For example, dictionary keys: "rank", "name", "birthdate";
dictionary values: 1, "Steven", datetime.date(2023, 4, 2).
A value can also be a `Column` of a literal expression, in which case it is taken as is.
.. versionadded:: 3.4.0
.. versionchanged:: 3.5.0
Added positional parameters.
kwargs
other variables that the user wants to set that can be referenced in the query
Returns
-------
pandas-on-Spark DataFrame
Examples
--------
Calling a built-in SQL function.
>>> ps.sql("SELECT * FROM range(10) where id > 7")
   id
0   8
1   9
>>> ps.sql("SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9)
   id
0   8
>>> mydf = ps.range(10)
>>> x = tuple(range(4))
>>> ps.sql("SELECT {ser} FROM {mydf} WHERE id IN {x}", ser=mydf.id, mydf=mydf, x=x)
   id
0   0
1   1
2   2
3   3
Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is
dropped.
>>> ps.sql('''
... SELECT m1.a, m2.b
... FROM {table1} m1 INNER JOIN {table2} m2
... ON m1.key = m2.key
... ORDER BY m1.a, m2.b''',
... table1=ps.DataFrame({"a": [1,2], "key": ["a", "b"]}),
... table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]}))
   a  b
0  1  3
1  2  4
2  2  5
Also, it is possible to query using Series.
>>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c'])
>>> ps.sql("SELECT {mydf.A} FROM {mydf}", mydf=psdf)
   A
0  1
1  2
2  3
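A plain Python string is substituted as a quoted SQL string literal. A minimal
sketch (the variable name `v` below is arbitrary), relying on Spark to cast the
literal for the comparison:
>>> ps.sql("SELECT * FROM range(10) WHERE id = {v}", v="8")
   id
0   8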
Named parameters prefixed with `:` can also be substituted with SQL literals.
>>> ps.sql("SELECT * FROM range(10) WHERE id > :bound1", args={"bound1":7})
   id
0   8
1   9
Positional parameters marked by `?` in the SQL query can be substituted as well.
>>> ps.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
   id
0   8
1   9
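A value in `args` can also be a `Column` of a literal expression, which is
passed through as is. A minimal sketch using `pyspark.sql.functions.lit`:
>>> from pyspark.sql.functions import lit
>>> ps.sql("SELECT * FROM range(10) WHERE id > :bound1", args={"bound1": lit(7)})
   id
0   8
1   9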
"""
if os.environ.get("PYSPARK_PANDAS_SQL_LEGACY") == "1":
from pyspark.pandas import sql_processor
warnings.warn(
"Deprecated in 3.3.0, and the legacy behavior "
"will be removed in the future releases.",
FutureWarning,
)
return sql_processor.sql(query, index_col=index_col, **kwargs)
session = default_session()
formatter = PandasSQLStringFormatter(session)
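# Temporary views registered while formatting the query are dropped in `clear()`;
# the try/finally ensures cleanup even if the SQL execution fails.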
try:
sdf = session.sql(formatter.format(query, **kwargs), args)
finally:
formatter.clear()
index_spark_columns, index_names = _get_index_map(sdf, index_col)
return DataFrame(
InternalFrame(
spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names
)
)
class PandasSQLStringFormatter(string.Formatter):
"""
A standard Python ``string.Formatter`` that understands pandas-on-Spark instances
as well as basic Python objects. This object must be cleared after use for a single
SQL query; it cannot be reused across multiple SQL queries without clearing.
"""
def __init__(self, session: SparkSession) -> None:
self._session: SparkSession = session
self._temp_views: List[Tuple[DataFrame, str]] = []
self._ref_sers: List[Tuple[Series, str]] = []
def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
ret = super(PandasSQLStringFormatter, self).vformat(format_string, args, kwargs)
for ref, n in self._ref_sers:
if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views):
# If referred DataFrame does not hold the given Series, raise an error.
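# For example, a hypothetical call such as
# ps.sql("SELECT {psdf.A} FROM {other}", psdf=psdf, other=other)
# ends up here, because {psdf} itself is never referenced as a table in the query.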
raise ValueError("The series in {%s} does not refer to any dataframe specified." % n)
return ret
def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
obj, first = super(PandasSQLStringFormatter, self).get_field(field_name, args, kwargs)
return self._convert_value(obj, field_name), first
def _convert_value(self, val: Any, name: str) -> Optional[str]:
"""
Converts the given value into a SQL string.
"""
if isinstance(val, pd.Series):
# Return the column name from pandas Series directly.
return ps.from_pandas(val).to_frame()._to_spark().columns[0]
elif isinstance(val, Series):
# Return the column name of the pandas-on-Spark Series only if its DataFrame was
# referred to. The check is done in `vformat` after everything is parsed.
self._ref_sers.append((val, name))
return val.to_frame()._to_spark().columns[0]
elif isinstance(val, (DataFrame, pd.DataFrame)):
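# Expose the DataFrame to the query by registering a temporary view under a
# unique generated name (e.g. "_pandas_api_<32 hex chars>") and substituting
# that name into the SQL text.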
df_name = "_pandas_api_%s" % str(uuid.uuid4()).replace("-", "")
if isinstance(val, pd.DataFrame):
# Don't store a temp view for plain pandas instances
# because there is no way to know which pandas DataFrame
# holds which Series.
val = ps.from_pandas(val)
else:
for df, n in self._temp_views:
if df is val:
return n
self._temp_views.append((val, df_name))
val._to_spark().createOrReplaceTempView(df_name)
return df_name
elif isinstance(val, str):
# This matches the behavior of the JVM implementation.
# See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
# sql/catalyst/expressions/literals.scala`
return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
else:
return val
def clear(self) -> None:
for _, n in self._temp_views:
self._session.catalog.dropTempView(n)
self._temp_views = []
self._ref_sers = []
def _test() -> None:
import os
import doctest
import sys
from pyspark.sql import SparkSession
import pyspark.pandas.sql_formatter
os.chdir(os.environ["SPARK_HOME"])
globs = pyspark.pandas.sql_formatter.__dict__.copy()
globs["ps"] = pyspark.pandas
spark = (
SparkSession.builder.master("local[4]")
.appName("pyspark.pandas.sql_formatter tests")
.getOrCreate()
)
(failure_count, test_count) = doctest.testmod(
pyspark.pandas.sql_formatter,
globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
)
spark.stop()
if failure_count:
sys.exit(-1)
if __name__ == "__main__":
_test()