Issue with Step 5: Retrieving Data Fields from Predefined Cohort Using Spark

Georgia Anne Brice

Hi everyone,

I hope you’re doing well. I’m new to using Spark and the UK Biobank Research Analysis Platform (RAP) and would appreciate any troubleshooting advice.

I created a predefined cohort using the Cohort Browser (MS diagnosis from ICD-10/ICD-9 codes) and adapted the prologue (https://github.com/dnanexus/OpenBio/blob/master/UKB_notebooks/ukb-rap-pheno-basic.ipynb) to extract ~100 fields for ~2,000 participants. I followed the RAP documentation (https://dnanexus.gitbook.io/uk-biobank-rap/working-on-the-research-analysis-platform/accessing-data/using-spark-to-analyze-tabular-data).

My setup:

  • Priority: High
  • Cluster Configuration: Spark Cluster
  • Instance Type: mem2_ssd1_v2_x16
  • Nodes: 8
  • Feature: Hail (default)

My PySpark code (Python notebook):

## Step 1: Preliminary Prologue Steps ##

# Import required packages
import pyspark
import dxpy
import dxdata

dxdata.__version__

# Spark initialisation (done only once; do not rerun this cell unless you select Kernel -> Restart kernel)
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)

# Automatically discover dispensed database name and dataset ID
dispensed_database = dxpy.find_one_data_object(
    classname='database',
    name='app*',
    folder='/',
    name_mode='glob',
    describe=True)
dispensed_database_name = dispensed_database['describe']['name']

dispensed_dataset = dxpy.find_one_data_object(
    typename='Dataset',
    name='app*.dataset',
    folder='/',
    name_mode='glob')
dispensed_dataset_id = dispensed_dataset['id']

## Step 2: Access Dataset and Entities ##

# Load dataset
dataset = dxdata.load_dataset(id=dispensed_dataset_id)

# Display dataset entities (e.g., participant data, hospital records, GP records)
dataset.entities

# Access main entities
participant = dataset["participant"]  # Participant data
hesin = dataset["hesin"]  # Hospital records

## Step 3: Define Fields of Interest ##

def field_names_for_id(field_id):
    """
    Returns all field names for a given UKB showcase field ID,
    including all instances and arrays if they exist.
    field_id is the bare numeric ID (e.g. "41270"); the regex adds the "p" prefix itself.
    """
    from distutils.version import LooseVersion

    fields = participant.find_fields(name_regex=r'^p{}(_i\d+)?(_a\d+)?$'.format(field_id))
    return sorted([f.name for f in fields], key=LooseVersion)
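Two details of `field_names_for_id()` are easy to trip over: the regex already adds the leading `p`, so the function expects the bare numeric showcase ID (`"41270"`, not `"p41270"`, which turns into `^pp41270...` and matches nothing), and `distutils`, the source of `LooseVersion`, is deprecated and no longer ships with Python 3.12. Below is a minimal stdlib-only sketch that reproduces the same name matching on a plain list of names and swaps in a natural-sort key. `natural_key` and `match_field_names` are hypothetical helpers, and the column names are made-up stand-ins for what `participant.find_fields()` would return.

```python
import re

def natural_key(name):
    """Split 'p41280_a10' into text and integer chunks so that a10 sorts after a9."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]

def match_field_names(all_names, field_id):
    """Mimic field_names_for_id(): field_id must be the bare numeric ID, e.g. "41280"."""
    pattern = re.compile(r"^p{}(_i\d+)?(_a\d+)?$".format(field_id))
    return sorted((n for n in all_names if pattern.match(n)), key=natural_key)

# Hypothetical column names standing in for participant.find_fields() results
columns = ["eid", "p41270", "p41280_a10", "p41280_a2", "p41280_a0", "p412800"]

print(match_field_names(columns, "41280"))   # ['p41280_a0', 'p41280_a2', 'p41280_a10']
print(match_field_names(columns, "p41280"))  # [] because the pattern becomes ^pp41280...
```

Because `re.split` with a capturing group always alternates text and digit chunks, the keys compare strings with strings and integers with integers, so no mixed-type comparison can occur.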

# Core identifiers
field_names = [
    "eid",     # Participant ID
    "p31",     # Sex
    "p21022",  # Age at recruitment
    "p34",     # Year of birth
]

# ICD diagnoses (pass the bare numeric ID; field_names_for_id() adds the "p" prefix itself)
field_names.extend(field_names_for_id("41270"))  # ICD-10 diagnoses (all instances/arrays)
field_names.extend(field_names_for_id("41271"))  # ICD-9 diagnoses (all instances/arrays)

# Dates of first in-patient diagnosis (all instances/arrays)
field_names.extend(field_names_for_id("41280"))  # ICD-10 dates
field_names.extend(field_names_for_id("41281"))  # ICD-9 dates

# Diet-related variables
field_names.extend([
    "p1329_i0",  # FFQ oily fish intake
    "p1339_i0",  # FFQ non-oily fish intake
    "p103140_i0", "p103140_i1", "p103140_i2", "p103140_i3", "p103140_i4",  # Fish consumer
    "p103150_i0", "p103150_i1", "p103150_i2", "p103150_i3", "p103150_i4",  # Tinned tuna intake
    "p103160_i0", "p103160_i1", "p103160_i2", "p103160_i3", "p103160_i4",  # Oily fish intake
    "p103170_i0", "p103170_i1", "p103170_i2", "p103170_i3", "p103170_i4",  # Breaded fish intake
    "p103180_i0", "p103180_i1", "p103180_i2", "p103180_i3", "p103180_i4",  # Battered fish intake
    "p103190_i0", "p103190_i1", "p103190_i2", "p103190_i3", "p103190_i4",  # White fish intake
    "p103200_i0", "p103200_i1", "p103200_i2", "p103200_i3", "p103200_i4",  # Prawns intake
    "p103210_i0", "p103210_i1", "p103210_i2", "p103210_i3", "p103210_i4",  # Lobster/crab intake
    "p103220_i0", "p103220_i1", "p103220_i2", "p103220_i3", "p103220_i4",  # Shellfish intake
    "p103230_i0", "p103230_i1", "p103230_i2", "p103230_i3", "p103230_i4",  # Other fish intake
    "p105010_i0", "p105010_i1", "p105010_i2", "p105010_i3", "p105010_i4",  # When diet questionnaire completed
    "p26002_i0", "p26002_i1", "p26002_i2", "p26002_i3", "p26002_i4",  # Energy intake
    "p100020_i0", "p100020_i1", "p100020_i2", "p100020_i3", "p100020_i4",  # Typical diet yesterday
    "p100010_i0", "p100010_i1", "p100010_i2", "p100010_i3", "p100010_i4",  # Portion size
])

# PHQ-9, GAD-7, and 4 baseline mental health questions
field_names.extend([
    "p20514", "p20510", "p20517", "p20519", "p20511", "p20507", "p20508", "p20518", "p20513",  # PHQ-9
    "p20506", "p20509", "p20520", "p20515", "p20516", "p20505", "p20512",  # GAD-7
    "p2050_i0", "p2060_i0", "p2080_i0", "p2070_i0",  # 4 baseline mental health questions
])

# Additional covariates
field_names.extend([
    "p22189",     # Townsend deprivation index
    "p20116_i0",  # Smoking status
    "p21002_i0",  # Weight
    "p50_i0",     # Standing height
    "p22032_i0",  # IPAQ activity group
    "p3140_i0", "p3140_i1", "p3140_i2", "p3140_i3",  # Pregnancy status
    "p21001_i0",  # BMI
    "p53_i0",     # Date of attending assessment centre
])
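Typing every instance suffix by hand (for example `"p103140_i0"` through `"p103140_i4"`) is error-prone, and a name that ends up in the list twice would appear in two batches and then clash as a duplicate non-key column in the Step 5 outer joins. A minimal sketch, assuming the instance counts are the ones already used in the lists above (not re-checked against the showcase); `with_instances` and `find_duplicates` are hypothetical helpers.

```python
from collections import Counter

def with_instances(base, n_instances):
    """Expand a base name such as "p103140" into ["p103140_i0", ..., "p103140_i{n-1}"]."""
    return [f"{base}_i{i}" for i in range(n_instances)]

def find_duplicates(names):
    """Return names listed more than once (they would clash as columns after joining)."""
    return sorted(name for name, count in Counter(names).items() if count > 1)

# Hypothetical rebuild of part of the diet block (instance counts copied from the list above)
diet_fields = ["p1329_i0", "p1339_i0"]
for base in ["p103140", "p103150", "p103160"]:
    diet_fields.extend(with_instances(base, 5))

print(len(diet_fields))                              # 17, i.e. 2 + 3 * 5
print(find_duplicates(diet_fields + ["p1329_i0"]))   # ['p1329_i0']
```

Running `find_duplicates(field_names)` on the full list before Step 5 is a cheap check that costs nothing on the cluster.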

## Step 4: Load Predefined Cohort ##

# Load predefined MS cohort from the "Fish_cohorts" folder
cohort = dxdata.load_cohort("/Fish_project/Fish_cohorts/MS_combined_cohort_primary_analysis_17-03-25")

## Step 5: Retrieving fields using batch processing ##

# Import pandas first
import pandas as pd

# Create a connection
connection = dxdata.connect()

# Optimised batch retrieval function using Spark DataFrames
def retrieve_batch_fields(batch_idx, field_batch):
    """
    Retrieves a batch of fields and returns it as a Spark DataFrame
    """
    print(f"Starting batch {batch_idx} with {len(field_batch)} fields...")

    # Always ensure 'eid' is included in each batch
    if "eid" not in field_batch and batch_idx > 0:
        field_batch = ["eid"] + field_batch

    # Print fields being retrieved
    print(f"Fields in this batch: {field_batch}")

    # Retrieve fields with cohort filtering
    try:
        df_batch = cohort.retrieve_fields(names=field_batch, engine=connection)
        # Force execution to detect problems early
        row_count = df_batch.count()
        print(f"Retrieved {row_count} rows for batch {batch_idx}")
        return df_batch  # Return Spark DataFrame directly
    except Exception as e:
        print(f"ERROR in batch {batch_idx}: {str(e)}")
        return None
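A stalled batch never raises, so the `try/except` in `retrieve_batch_fields()` cannot report it. One way to make a stall visible is to run the blocking call in a worker thread with a deadline. This is a generic stdlib sketch, not a dxdata feature; `run_with_timeout` is a hypothetical helper and any timeout value is an arbitrary assumption. The worker thread is not killed on timeout, so on the cluster the underlying Spark job would keep running until cancelled (PySpark's `SparkContext.cancelAllJobs()` is one option).

```python
import concurrent.futures
import time

def run_with_timeout(fn, timeout_s, *args, **kwargs):
    """Run fn(*args, **kwargs) in a worker thread; raise TimeoutError after timeout_s seconds.

    The worker thread is NOT stopped on timeout; it keeps running in the background.
    Exceptions raised by fn itself are propagated to the caller.
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(fn, *args, **kwargs)
    try:
        return future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        name = getattr(fn, "__name__", "call")
        raise TimeoutError(f"{name} did not finish within {timeout_s} s") from None
    finally:
        # Return immediately instead of waiting for a stuck worker
        executor.shutdown(wait=False)

# Local demonstration with stand-in functions (no Spark involved)
print(run_with_timeout(lambda x: x * 2, 1.0, 21))  # 42
try:
    run_with_timeout(time.sleep, 0.1, 0.5)  # simulates a call that hangs
except TimeoutError as err:
    print("Stalled:", err)
```

In the notebook this might wrap the user's own call, e.g. `run_with_timeout(lambda: cohort.retrieve_fields(names=field_batch, engine=connection).count(), 900)`, where the 900-second limit is only a placeholder.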

# Define batch size
batch_size = 10

# Initialise final DataFrame
final_spark_df = None

# Calculate total number of batches (ceiling division)
total_batches = (len(field_names) + batch_size - 1) // batch_size
print(f"Processing {len(field_names)} fields in {total_batches} batches...")

# Track successful and failed batches
successful_batches = set()
failed_batches = []

# Process in batches
for batch_num in range(total_batches):
    start_idx = batch_num * batch_size
    end_idx = min((batch_num + 1) * batch_size, len(field_names))

    # Skip if already processed
    if batch_num in successful_batches:
        print(f"Skipping batch {batch_num} (already processed)")
        continue

    # Extract current batch fields
    current_fields = field_names[start_idx:end_idx]

    # Process batch
    print(f"Processing batch {batch_num+1}/{total_batches} ({start_idx}-{end_idx})")
    batch_df = retrieve_batch_fields(batch_num, current_fields)

    # If batch retrieval failed, record it and move on (no automatic retry is implemented)
    if batch_df is None:
        failed_batches.append(batch_num)
        print(f"Adding batch {batch_num} to failed batches list.")
        continue

    # Handle first batch vs subsequent batches
    if final_spark_df is None:
        final_spark_df = batch_df
        print("First batch becomes base dataset")
    else:
        # Join in Spark instead of pandas
        print("Joining with existing data...")
        final_spark_df = final_spark_df.join(batch_df, on="eid", how="outer")
        print(f"Joined dataset now has {len(final_spark_df.columns)} columns")

    # Mark as successfully processed
    successful_batches.add(batch_num)

    # Cache results periodically to optimise performance
    if (batch_num + 1) % 5 == 0:
        print("Caching intermediate results...")
        final_spark_df.cache().count()

    # Save checkpoint using Spark's native Parquet format
    if (batch_num + 1) % 5 == 0 or batch_num == total_batches - 1:
        checkpoint_path = f'/tmp/checkpoint_batch_{batch_num+1}'
        print(f"Saving checkpoint to {checkpoint_path}")
        final_spark_df.write.mode("overwrite").parquet(checkpoint_path)

# Report on any failed batches
if failed_batches:
    print(f"WARNING: {len(failed_batches)} batches failed: {failed_batches}")
else:
    print("All batches processed successfully!")

# Convert final Spark DataFrame to pandas at the end
if final_spark_df is None:
    raise RuntimeError("No batches were retrieved successfully; nothing to convert.")
print("Converting final dataset to pandas...")
final_df = final_spark_df.toPandas()

# Final dataset stats
print(f"Final dataset: {final_df.shape[0]} rows × {final_df.shape[1]} columns")
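The batching arithmetic in Step 5 (ceiling division for the batch count, slicing, and re-adding `eid` to every batch after the first) can be folded into one generator, which makes it easy to inspect the batches before any Spark job runs. A pure-Python sketch; `make_batches` is a hypothetical name and the field names in the example are placeholders. Unlike the Step 5 loop, it does not count `eid` toward `batch_size`.

```python
def make_batches(field_names, batch_size, key="eid"):
    """Yield lists holding the join key once plus at most batch_size other fields."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    others = [f for f in field_names if f != key]
    for start in range(0, len(others), batch_size):
        yield [key] + others[start:start + batch_size]

fields = ["eid"] + [f"p{i}" for i in range(23)]  # placeholder field names
batches = list(make_batches(fields, 10))
print(len(batches))               # 3, i.e. ceil(23 / 10)
print([len(b) for b in batches])  # [11, 11, 4]
```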

 

## Step 6: Save and Upload Data ##

# Save as CSV with today’s date
final_df.to_csv('MS_combined_cohort_primary_analysis_with_data_fields_19-03-25.csv', index=False)

%%bash
# Upload CSV to the correct folder in the project with today’s date
dx upload MS_combined_cohort_primary_analysis_with_data_fields_19-03-25.csv --path /Fish_project/Fish_cohorts/

%%bash
# Upload notebook
dx upload Spark_JupyterLab_Creating_Primary_Dataset_19-03-25.ipynb --path /Fish_project/Spark_notebooks/
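The Step 6 comments say the file is saved "with today’s date", but the date is hard-coded in the file name. A minimal sketch that derives it in the same DD-MM-YY format and builds the matching `dx upload ... --path ...` command from Python, so the CSV name and the upload cannot drift apart. `dated_filename` and `dx_upload_argv` are hypothetical helpers; the command is only constructed here, and running it (e.g. with `subprocess.run(argv, check=True)`) assumes the `dx` CLI is on the PATH, as the `%%bash` cells already rely on.

```python
import datetime
import shlex

def dated_filename(prefix, ext, day=None):
    """Return e.g. 'prefix_19-03-25.csv', using DD-MM-YY like the original file names."""
    day = day or datetime.date.today()
    return f"{prefix}_{day.strftime('%d-%m-%y')}.{ext}"

def dx_upload_argv(local_file, project_folder):
    """Build the argument list for `dx upload <file> --path <folder>` (not executed here)."""
    return ["dx", "upload", local_file, "--path", project_folder]

# Fixed date so the example is reproducible; omit `day=` to use today's date
csv_name = dated_filename("MS_combined_cohort_primary_analysis_with_data_fields", "csv",
                          day=datetime.date(2025, 3, 19))
print(csv_name)  # MS_combined_cohort_primary_analysis_with_data_fields_19-03-25.csv
print(shlex.join(dx_upload_argv(csv_name, "/Fish_project/Fish_cohorts/")))
```

In the notebook, `final_df.to_csv(csv_name, index=False)` followed by `subprocess.run(dx_upload_argv(csv_name, "/Fish_project/Fish_cohorts/"), check=True)` would replace the hard-coded name in both places.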

Steps 1–4 executed successfully. However, Step 5 (retrieving data fields) keeps stalling, even when adjusting the approach:

  • Initial attempt: retrieving all fields at once:

# Create a connection
connection = dxdata.connect()
df = cohort.retrieve_fields(names=field_names, engine=connection)
df_pandas = df.toPandas()
print("Cohort size:", df_pandas.shape[0])

  • Modified attempts: Batches of 10 (as above), 5, and 1 field at a time (adapted the above for each).

For Step 5, the log shows no clear error messages, only this warning:

opt/conda/lib/python3.11/site-packages/thrift/transport/TSSLSocket.py:53: DeprecationWarning: ssl.PROTOCOL_TLS is deprecated
  self._context = ssl.SSLContext(ssl_version)

Reducing the batch size does not help: CPU usage drops to 0% and not even the first field (eid) is retrieved.
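The only log line is a `DeprecationWarning` emitted when the thrift client creates its SSL context. A warning of this kind does not by itself stop execution, so on its own it is unlikely to explain the stall. If it clutters the output, it can be filtered by message with the standard `warnings` module; this sketch assumes the message text stays exactly as shown in the log, and `ignore_tls_deprecation` is a hypothetical helper.

```python
import warnings

TLS_MESSAGE = r"ssl\.PROTOCOL_TLS is deprecated"  # regex matched against the start of the message

def ignore_tls_deprecation():
    """Suppress only the ssl.PROTOCOL_TLS notice; all other warnings are unaffected."""
    warnings.filterwarnings("ignore", message=TLS_MESSAGE, category=DeprecationWarning)

ignore_tls_deprecation()  # in the notebook, presumably before dxdata.connect()

# Self-check: emit a matching and a non-matching warning and see which one survives
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record everything not explicitly ignored
    ignore_tls_deprecation()         # inserted ahead of the "always" rule
    warnings.warn("ssl.PROTOCOL_TLS is deprecated", DeprecationWarning)
    warnings.warn("some other deprecation", DeprecationWarning)

print([str(w.message) for w in caught])  # ['some other deprecation']
```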

Has anyone encountered similar issues with retrieving multiple fields? Could this be related to cluster configuration, memory limits, or the way the cohort is loaded? Any suggestions would be greatly appreciated!

Thank you in advance!
