在 PySpark 中插入纬度/经度列以获得轨迹中等距的点

问题描述

我有一个 PySpark 数据框,其中包含由“trajectories_id”列标识的不同轨迹的纬度/经度点。每个轨迹由不同数量的点组成。 下面是一个非常简单的例子。请注意,在我的情况下,我可以有更复杂的(非直线)线。

trajectory_id 纬度 经度
1 45 5
1 45 6
1 45 9
2 46 1
2 46 5

我想要做的是对每个trajectory_id进行插值以获得用户定义的等距点数。请注意,第一个和最后一个点是固定的。在上面的例子中,如果每条线的点数设置为 5,结果将是:

trajectory_id 纬度 经度
1 45 5
1 45 6
1 45 7
1 45 8
1 45 9
2 46 1
2 46 2
2 46 3
2 46 4
2 46 5

使用 Pandas,实现此目的的一种方法可能是使用 shapely 库,转换线串中的每一行,然后使用 numpy 的 linspace 和 shapely 的 interpolate。我想知道在 PySpark 中是否有更有效的方法来实现相同的结果

import pandas as pd
from shapely.geometry import Point,Linestring
import numpy as np
import geopandas as gpd
df = pd.DataFrame([[1,45,5],[1,6],9],[2,46,1],5]],columns=['trajectory_id','latitude','longitude'])
df['point']=df[["longitude","latitude"]].apply(Point,axis=1)
geo_df = pd.DataFrame(columns=['trajectory_id','longitude'])
for i in range(1,df['trajectory_id'].max()+1):
    df_line = Linestring(df['point'][df['trajectory_id']==i].reset_index(drop=True))
    distances = np.linspace(0,df_line.length,5)
    df_points = gpd.GeoDataFrame([df_line.interpolate(distance) for distance in distances],columns=['geometry'])
    df_points['longitude'] = df_points.geometry.x
    df_points['latitude'] = df_points.geometry.y
    df_points['trajectory_id'] = i
    geo_df=geo_df.append(df_points)
del geo_df['geometry']

解决方法

通过使用 UDF 生成一系列整数,我可以生成您预期的数据。它将适用于纬度和经度的范围(您的示例仅显示一系列经度)

def nums(f,t):
    return list(range(f,t + 1))

(df
    .groupBy('trajectory_id') # Group by `trajectory_id` to get min max lat and lon
    .agg(
        F.min('lat').alias('min_lat'),F.max('lat').alias('max_lat'),F.min('lon').alias('min_lon'),F.max('lon').alias('max_lon'),)
    .withColumn('lat_arr',F.udf(nums,T.ArrayType(T.IntegerType()))('min_lat','max_lat')) # using UDF to generate list of lat values based on min max lat
    .withColumn('lon_arr',T.ArrayType(T.IntegerType()))('min_lon','max_lon')) # using UDF to generate list of lon values based on min max lon
    .withColumn('lat',F.explode('lat_arr')) # break down lat values to multiple rows
    .withColumn('lon',F.explode('lon_arr')) # break down lon values to multiple rows
    .drop('min_lat','max_lat','min_lon','max_lon','lat_arr','lon_arr') # drop temporary columns
    .show()
)

# +-------------+---+---+
# |trajectory_id|lat|lon|
# +-------------+---+---+
# |            1| 45|  5|
# |            1| 45|  6|
# |            1| 45|  7|
# |            1| 45|  8|
# |            1| 45|  9|
# |            2| 46|  1|
# |            2| 46|  2|
# |            2| 46|  3|
# |            2| 46|  4|
# |            2| 46|  5|
# +-------------+---+---+