论文辅助笔记:处理geolife数据

论文笔记:Context-aware multi-head self-attentional neural network model fornext location prediction-CSDN博客

对应命令行里

bash 复制代码
    python preprocessing/geolife.py 20

这一句

1 读取geolife数据

python 复制代码
pfs, _ = read_geolife(config["raw_geolife"], print_progress=True)

2 生成staypoint 数据

根据geolife数据,使用滑动窗口的方法获取staypoint

同时geolife DataFrame加一列staypoint

python 复制代码
pfs, sp = pfs.as_positionfixes.generate_staypoints(
        gap_threshold=24 * 60, 
        include_last=True, 
        print_progress=True, 
        dist_threshold=200, 
        time_threshold=30, 
        n_jobs=-1
    )

2.1 判断staypoint是否是活动对应的staypoint

python 复制代码
sp = sp.as_staypoints.create_activity_flag(
    method="time_threshold", 
    time_threshold=25)

如果staypoint停留时间>25min,那么是为一个活跃的staypoint

3 在两个stypoint之间的部分创建行程段

在两个stypoint之间的部分创建行程段

【如果两个非staypoint之间的时间间隔大于阈值的话,视为两个行程段】

python 复制代码
pfs, tpls = pfs.as_positionfixes.generate_triplegs(sp)

4 根据停留点和行程段创建trip数据集

python 复制代码
sp, tpls, trips = generate_trips(sp, tpls, add_geometry=False)

staypoint之前的trip_id,之后的trip_id

行程和行程的trip_id

行程和行程的始末staypoint_id

5 每个用户时间跟踪质量相关内容

python 复制代码
quality_path = os.path.join(".", "data", "quality")
quality_file = os.path.join(quality_path, "geolife_slide_filtered.csv")
quality_filter = {"day_filter": 50, "window_size": 10}
valid_user = calculate_user_quality(
    sp.copy(), 
    trips.copy(), 
    quality_file, 
    quality_filter)
'''

array([  0,   2,   3,   4,   5,  10,  11,  12,  13,  14,  15,  17,  20,
        22,  24,  25,  26,  28,  30,  34,  35,  36,  37,  38,  39,  40,
        41,  42,  43,  44,  45,  46,  50,  51,  52,  55,  56,  58,  59,
        62,  63,  65,  66,  67,  68,  71,  73,  74,  78,  81,  82,  83,
        84,  85,  89,  91,  92,  95,  96,  97,  99, 101, 102, 104, 110,
       111, 112, 114, 115, 119, 122, 125, 126, 128, 130, 131, 132, 133,
       134, 140, 142, 144, 147, 153, 155, 163, 167, 168, 172, 174, 179,
       181], dtype=int64)
'''
'''
        
        valid_user------有记录天数大于day_filter天的那些user_id
        这个函数同时返回了一个csv文件,记录了这些user_id的的时间跟踪质量
    
'''

这个函数的原理如下:

5.0 准备部分

python 复制代码
trips["started_at"] = pd.to_datetime(trips["started_at"]).dt.tz_localize(None)
trips["finished_at"] = pd.to_datetime(trips["finished_at"]).dt.tz_localize(None)
sp["started_at"] = pd.to_datetime(sp["started_at"]).dt.tz_localize(None)
sp["finished_at"] = pd.to_datetime(sp["finished_at"]).dt.tz_localize(None)

sp["type"] = "sp"
trips["type"] = "tpl"
df_all = pd.concat([sp, trips])

5.1 横跨多天的staypoint/行程段进行拆分

python 复制代码
df_all = _split_overlaps(df_all, granularity="day")
'''
如果一个trips/staypoint横跨多天了,那么就拆分成两个trips/staypoint
'''

5.2 更新每个trips/staypoint的持续时

python 复制代码
df_all["duration"] = (df_all["finished_at"] - df_all["started_at"]).dt.total_seconds()

5.3 计算每个用户的时间跟踪质量

trackintel 笔记:generate_staypoints,create_activity_flag-CSDN博客

python 复制代码
total_quality = temporal_tracking_quality(df_all, granularity="all")

某一用户从第一条记录~最后一条记录 这一段时间内,有多少比例的时间是在staypoint/trip的范围内的

【但这个好像没啥用?】

5.4 计算每个用户第一条&最后一条记录之间的天数跨度

python 复制代码
total_quality["days"] = (
        df_all.groupby("user_id").apply(lambda x: (x["finished_at"].max() - x["started_at"].min()).days).values
    )

5.5 筛选时间跨度>阈值的user id

python 复制代码
user_filter_day = (
        total_quality.loc[(total_quality["days"] > quality_filter["day_filter"])]
        .reset_index(drop=True)["user_id"]
        .unique()
    )

5.6 user_id,每window_size周的时间跟踪质量

python 复制代码
sliding_quality = (
        df_all.groupby("user_id")
        .apply(_get_tracking_quality, window_size=quality_filter["window_size"])
        .reset_index(drop=True)
    )

5.6.1 _get_tracking_quality

python 复制代码
def _get_tracking_quality(df, window_size):

    weeks = (df["finished_at"].max() - df["started_at"].min()).days // 7
    '''
    一个user有几周有数据
    '''
    start_date = df["started_at"].min().date()

    quality_list = []
    # construct the sliding week gdf
    for i in range(0, weeks - window_size + 1):
        curr_start = datetime.datetime.combine(start_date + datetime.timedelta(weeks=i), datetime.time())
        curr_end = datetime.datetime.combine(curr_start + datetime.timedelta(weeks=window_size), datetime.time())
        #这里window_size=10,也即10周

        # the total df for this time window
        cAll_gdf = df.loc[(df["started_at"] >= curr_start) & (df["finished_at"] < curr_end)]
        #这10周这个用户的记录
        if cAll_gdf.shape[0] == 0:
            continue
        total_sec = (curr_end - curr_start).total_seconds()

        quality_list.append([i, cAll_gdf["duration"].sum() / total_sec])
        #这10周有记录的比例
    ret = pd.DataFrame(quality_list, columns=["timestep", "quality"])
    ret["user_id"] = df["user_id"].unique()[0]
    return ret

5.7 有记录天数大于50天的那些user_id,每window_size周的时间跟踪质量

python 复制代码
filter_after_day = sliding_quality.loc[sliding_quality["user_id"].isin(user_filter_day)]
filter_after_day

5.8 每个筛选后的user_id的平均滑动时间跟踪质量

python 复制代码
filter_after_user_quality = filter_after_day.groupby("user_id", as_index=False)["quality"].mean()

5.9 函数结束

python 复制代码
filter_after_user_quality.to_csv(file_path, index=False)
#平均滑动时间跟踪质量保存至本地
return filter_after_user_quality["user_id"].values
#返回持续时间大于50天的数据

6 筛选staypoint

6.1 筛选在valid_user里面的

python 复制代码
sp = sp.loc[sp["user_id"].isin(valid_user)]

6.2 筛选活跃的

python 复制代码
sp = sp.loc[sp["is_activity"] == True]
sp

7 聚合staypoint(成为station)

python 复制代码
sp, locs = sp.as_staypoints.generate_locations(
        epsilon=50, 
        num_samples=2, 
        distance_metric="haversine", 
        agg_level="dataset", 
        n_jobs=-1
    )

7.1 去除不在station里面的staypoint(因为这个任务是next station prediction)

python 复制代码
sp = sp.loc[~sp["location_id"].isna()].copy()

7.2 station去重

不同user 可能共享一个location,相同位置的location只保留一个

python 复制代码
locs = locs[~locs.index.duplicated(keep="first")]

7.2 将station信息保存至locations_geolife.csv

8 合并时间阈值内的staypoint

python 复制代码
sp_merged = sp.as_staypoints.merge_staypoints(
        triplegs=pd.DataFrame([]), 
        max_time_gap="1min", 
        agg={"location_id": "first"}
    )

如果两个停留点之间的最大持续时间小于1分钟,则进行合并

9 每个staypoint的持续时间

python 复制代码
sp_merged["duration"] = (sp_merged["finished_at"] - sp_merged["started_at"]).dt.total_seconds() // 60

10 添加和计算新的时间相关字段

python 复制代码
sp_time = enrich_time_info(sp_merged)
sp_time

10.1 enrich_time_info(sp)

python 复制代码
def enrich_time_info(sp):
    sp = sp.groupby("user_id", group_keys=False).apply(_get_time)
    #使用 groupby 根据 user_id 对数据进行分组,并应用辅助函数 _get_time 处理每个组的数据。

    sp.drop(columns={"finished_at", "started_at"}, inplace=True)
    #删除 finished_at 和 started_at 列

    sp.sort_values(by=["user_id", "start_day", "start_min"], inplace=True)
    #对数据进行排序

    sp = sp.reset_index(drop=True)

    # 
    sp["location_id"] = sp["location_id"].astype(int)
    sp["user_id"] = sp["user_id"].astype(int)

    # final cleaning, reassign ids
    sp.index.name = "id"
    sp.reset_index(inplace=True)
    return sp

10.2 _get_time(df)

python 复制代码
def _get_time(df):
    min_day = pd.to_datetime(df["started_at"].min().date())
    #将 started_at 的最小日期(min_day)作为基准点,用于计算其他时间点相对于此日期的差异

    df["started_at"] = df["started_at"].dt.tz_localize(tz=None)
    df["finished_at"] = df["finished_at"].dt.tz_localize(tz=None)

    df["start_day"] = (df["started_at"] - min_day).dt.days
    df["end_day"] = (df["finished_at"] - min_day).dt.days
    #计算 start_day 和 end_day 字段,这两个字段表示相对于 min_day 的天数差异。

    df["start_min"] = df["started_at"].dt.hour * 60 + df["started_at"].dt.minute
    df["end_min"] = df["finished_at"].dt.hour * 60 + df["finished_at"].dt.minute
    #计算 start_min 和 end_min 字段,这些字段表示一天中的分钟数,用于精确到分钟的时间差异计算

    df.loc[df["end_min"] == 0, "end_min"] = 24 * 60
    #如果 end_min 等于 0,表示结束时间为午夜,为了避免计算错误,手动将其设置为 1440(即24小时*60分钟)

    df["weekday"] = df["started_at"].dt.weekday
    #计算 weekday 字段,表示 started_at 所在的星期几(0代表星期一,6代表星期日)

    return df

11 sp_time 存入sp_time_temp_geolife.csv

12 _filter_sp_history(sp_time)

这一部分写的有点繁琐,有一些语句都是没有必要的,我精简一下

12.0 辅助函数

12.0.1 split_dataset

12.0.2 get_valid_sequence

12.1 划分训练、验证、测试集

python 复制代码
train_data, vali_data, test_data = split_dataset(sp_time)

每一个user 前60%天 训练,中间20%天验证,后20%天测试

12.2 获取所有"valid"的行id

所谓valid,指的是那些在给定时间窗口(previous_day 天,这里例子中是7天)内,在当前记录之前有至少三条记录的行ID

python 复制代码
previous_day_ls = [7]
all_ids = sp[["id"]].copy()

for previous_day in previous_day_ls:
        valid_ids = get_valid_sequence(train_data, previous_day=previous_day)
        valid_ids.extend(get_valid_sequence(vali_data, previous_day=previous_day))
        valid_ids.extend(get_valid_sequence(test_data, previous_day=previous_day))

        all_ids[f"{previous_day}"] = 0
        all_ids.loc[all_ids["id"].isin(valid_ids), f"{previous_day}"] = 1
python 复制代码
all_ids.set_index("id", inplace=True)
final_valid_id = all_ids.loc[all_ids.sum(axis=1) == all_ids.shape[1]].reset_index()["id"].values
#这一行写的很复杂,其实就是
'''
all_ids[all_ids['7']==1].index.values
'''

筛选所有valid的行id

12.3 筛选train、valid、test中valid的行对应的user_id

python 复制代码
valid_users_train = train_data.loc[train_data["id"].isin(final_valid_id), "user_id"].unique()
valid_users_vali = vali_data.loc[vali_data["id"].isin(final_valid_id), "user_id"].unique()
valid_users_test = test_data.loc[test_data["id"].isin(final_valid_id), "user_id"].unique()
valid_users_train

12.4 在train、test、valid上都有的user

python 复制代码
valid_users = set.intersection(set(valid_users_train), set(valid_users_vali), set(valid_users_test))

len(valid_users)
#47

12.5 筛选对应的staypoint

python 复制代码
filtered_sp = sp_time.loc[sp_time["user_id"].isin(valid_users)].copy()

12.5 valid_user_id和staypoint 分别保存

python 复制代码
data_path = f"./data/valid_ids_geolife.pk"
with open(data_path, "wb") as handle:
        pickle.dump(final_valid_id, handle, protocol=pickle.HIGHEST_PROTOCOL)
filtered_sp.to_csv(f"./data/dataset_geolife.csv", index=False)
相关推荐
你要飞4 小时前
Hexo + Butterfly 博客添加 Live2D 看板娘指南
笔记
ajsbxi7 小时前
【Java 基础】核心知识点梳理
java·开发语言·笔记
呱呱巨基7 小时前
vim编辑器
linux·笔记·学习·编辑器·vim
新子y8 小时前
【小白笔记】普通二叉树(General Binary Tree)和二叉搜索树的最近公共祖先(LCA)
开发语言·笔记·python
聪明的笨猪猪8 小时前
Java JVM “调优” 面试清单(含超通俗生活案例与深度理解)
java·经验分享·笔记·面试
爱学习的uu8 小时前
CURSOR最新使用指南及使用思路
人工智能·笔记·python·软件工程
YuCaiH8 小时前
Linux文件处理
linux·笔记·嵌入式
Cathy Bryant8 小时前
大模型损失函数(二):KL散度(Kullback-Leibler divergence)
笔记·神经网络·机器学习·数学建模·transformer
qq_398586549 小时前
Threejs入门学习笔记
javascript·笔记·学习
hour_go9 小时前
TCP/IP协议相关知识点
网络·笔记·网络协议·tcp/ip