Source code for dbsnaplake.polars_utils

# -*- coding: utf-8 -*-

"""

"""

import typing as T
import io

import polars as pl
from s3pathlib import S3Path

from .typehint import T_OPTIONAL_KWARGS
from .constants import S3_METADATA_KEY_N_RECORD
from .partition import encode_hive_partition

if T.TYPE_CHECKING:  # pragma: no cover
    from mypy_boto3_s3.client import S3Client


[docs]def write_parquet_to_s3( df: pl.DataFrame, s3path: S3Path, s3_client: "S3Client", polars_write_parquet_kwargs: T.Optional[T.Dict[str, T.Any]] = None, s3pathlib_write_bytes_kwargs: T.Optional[T.Dict[str, T.Any]] = None, ) -> T.Tuple[int, int, str]: """ Write polars dataframe to AWS S3 as a parquet file. The original ``polars.write_parquet`` method doesn't work with moto. :param df: ``polars.DataFrame`` object. :param s3path: ``s3pathlib.S3Path`` object. :param s3_client: ``boto3.client("s3")`` object. :param polars_write_parquet_kwargs: Keyword arguments for ``polars.DataFrame.write_parquet`` method. See https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_parquet.html :param s3pathlib_write_bytes_kwargs: Keyword arguments for ``s3path.write_bytes`` method. See https://s3pathlib.readthedocs.io/en/latest/s3pathlib/core/rw.html#s3pathlib.core.rw.ReadAndWriteAPIMixin.write_bytes :return: A tuple of three values: - The number of bytes written to S3, i.e., the size of the parquet file. - The number of records in the DataFrame. - The ETag of the S3 object. """ if polars_write_parquet_kwargs is None: polars_write_parquet_kwargs = {} buffer = io.BytesIO() df.write_parquet(buffer, **polars_write_parquet_kwargs) if s3pathlib_write_bytes_kwargs is None: s3pathlib_write_bytes_kwargs = {} s3pathlib_write_bytes_kwargs["content_type"] = "application/x-parquet" b = buffer.getvalue() s3path_new = s3path.write_bytes(b, bsm=s3_client, **s3pathlib_write_bytes_kwargs) size = len(b) n_record = df.shape[0] etag = s3path_new.etag return (size, n_record, etag)
[docs]def write_data_file( df: pl.DataFrame, s3path: S3Path, s3_client: "S3Client", polars_write_parquet_kwargs: T_OPTIONAL_KWARGS = None, s3pathlib_write_bytes_kwargs: T_OPTIONAL_KWARGS = None, ) -> T.Tuple[int, int, str]: """ Write the DataFrame to the given S3Path as a Parquet file, also attach additional information related to the Snapshot Data File. It is a wrapper of the write_parquet_to_s3 function, make the final code shorter. :return: A tuple of three values: - The number of bytes written to S3, i.e., the size of the parquet file. - The number of records in the DataFrame. - The ETag of the S3 object. """ if s3pathlib_write_bytes_kwargs is None: s3pathlib_write_bytes_kwargs = {} more_metadata = { S3_METADATA_KEY_N_RECORD: str(df.shape[0]), } if "metadata" in s3pathlib_write_bytes_kwargs: s3pathlib_write_bytes_kwargs["metadata"].update(more_metadata) else: s3pathlib_write_bytes_kwargs["metadata"] = more_metadata return write_parquet_to_s3( df=df, s3path=s3path, s3_client=s3_client, polars_write_parquet_kwargs=polars_write_parquet_kwargs, s3pathlib_write_bytes_kwargs=s3pathlib_write_bytes_kwargs, )
def read_parquet_from_s3( s3path: S3Path, s3_client: "S3Client", polars_read_parquet_kwargs: T_OPTIONAL_KWARGS = None, s3pathlib_read_bytes_kwargs: T_OPTIONAL_KWARGS = None, ) -> pl.DataFrame: if polars_read_parquet_kwargs is None: polars_read_parquet_kwargs = {} if s3pathlib_read_bytes_kwargs is None: s3pathlib_read_bytes_kwargs = {} b = s3path.read_bytes(bsm=s3_client, **s3pathlib_read_bytes_kwargs) df = pl.read_parquet(b, **polars_read_parquet_kwargs) return df def read_many_parquet_from_s3( s3path_list: T.List[S3Path], s3_client: "S3Client", polars_read_parquet_kwargs: T_OPTIONAL_KWARGS = None, s3pathlib_read_bytes_kwargs: T_OPTIONAL_KWARGS = None, ) -> pl.DataFrame: sub_df_list = list() for s3path in s3path_list: sub_df = read_parquet_from_s3( s3path=s3path, s3_client=s3_client, polars_read_parquet_kwargs=polars_read_parquet_kwargs, s3pathlib_read_bytes_kwargs=s3pathlib_read_bytes_kwargs, ) sub_df_list.append(sub_df) df = pl.concat(sub_df_list) return df
[docs]def group_by_partition( df: pl.DataFrame, s3dir: S3Path, filename: str, partition_keys: T.List[str], sort_by: T.List[str], ) -> T.List[T.Tuple[pl.DataFrame, S3Path]]: """ Group dataframe by partition keys and locate the S3 location for each partition. :param df: ``polars.DataFrame`` object. :param s3dir: ``s3pathlib.S3Path`` object, the root directory of the S3 location. :param filename: filename of the parquet file. for example: "data.parquet". :param partition_keys: list of partition keys. for example: ["year", "month"]. :param sort_by: list of columns to sort by. for example: ["create_time"]. use empty list if no sorting is needed. """ results = list() partition_values: T.List[str] for ith, (partition_values, sub_df) in enumerate( df.group_by(partition_keys), start=1, ): sub_df = sub_df.drop(partition_keys) if sort_by: sub_df = sub_df.sort(by=sort_by) kvs = dict(zip(partition_keys, partition_values)) partition_relpath = encode_hive_partition(kvs=kvs) s3path = s3dir.joinpath(partition_relpath, filename) results.append((sub_df, s3path)) return results