Issue with Step 5: Retrieving Data Fields from Predefined Cohort Using Spark
Hi everyone,
I hope you’re doing well. I’m new to using Spark and the UK Biobank Research Analysis Platform (RAP) and would appreciate any troubleshooting advice.
I created a predefined cohort using the Cohort Browser (MS diagnosis from ICD-10/ICD-9 codes) and adapted the prologue (https://github.com/dnanexus/OpenBio/blob/master/UKB_notebooks/ukb-rap-pheno-basic.ipynb) to extract ~100 fields for ~2,000 participants. I followed the RAP documentation (https://dnanexus.gitbook.io/uk-biobank-rap/working-on-the-research-analysis-platform/accessing-data/using-spark-to-analyze-tabular-data).
My setup:
- Priority: High
- Cluster Configuration: Spark Cluster
- Instance Type: mem2_ssd1_v2_x16
- Nodes: 8
- Feature: Hail (default)
My PySpark syntax for Python Notebook:
## Step 1: Preliminary Prologue Steps ##
# Import required packages
import pyspark
import dxpy
import dxdata
dxdata.__version__

# Spark initialisation. Run this once per kernel; only re-run it after
# Kernel -> Restart kernel.
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)

# Locate the dispensed database: a 'database' object in the project root
# whose name matches the glob 'app*'. Its name is read from the describe
# output requested with describe=True.
dispensed_database = dxpy.find_one_data_object(
    name='app*',
    name_mode='glob',
    classname='database',
    folder='/',
    describe=True,
)
dispensed_database_name = dispensed_database['describe']['name']

# Locate the dispensed dataset the same way ('app*.dataset', type 'Dataset')
# and keep its object id for loading.
dispensed_dataset = dxpy.find_one_data_object(
    name='app*.dataset',
    name_mode='glob',
    typename='Dataset',
    folder='/',
)
dispensed_dataset_id = dispensed_dataset['id']

## Step 2: Access Dataset and Entities ##
# Load the dataset found above
dataset = dxdata.load_dataset(id=dispensed_dataset_id)

# Show the dataset's entities (e.g., participant data, hospital records, GP records)
dataset.entities

# Entities used in the rest of the notebook
participant = dataset["participant"]  # Participant data
hesin = dataset["hesin"]  # Hospital records
## Step 3: Define Fields of Interest ##
def field_names_for_id(field_id):
    """
    Return all participant field names for a given UKB showcase field ID,
    including every instance (_iN) and array (_aN) variant that exists.

    Args:
        field_id: Showcase field ID, e.g. "41270" or 41270. A leading "p"
            (e.g. "p41270") is accepted and ignored, because the "p" prefix
            is supplied by the pattern built below.

    Returns:
        List of matching field names in natural order, i.e. numeric parts
        compare as numbers (..._a2 sorts before ..._a10).
    """
    # Local import, like the original helper, so the cell stays self-contained.
    import re

    bare_id = str(field_id).strip()
    # Without this, "p41270" would produce the pattern ^pp41270..., which
    # matches no field and silently yields an empty list.
    if bare_id.startswith("p"):
        bare_id = bare_id[1:]

    fields = participant.find_fields(
        name_regex=r'^p{}(_i\d+)?(_a\d+)?$'.format(bare_id)
    )

    def natural_key(name):
        # re.split with a capturing group returns alternating text / digit
        # runs: even positions are text, odd positions are digits. Turning
        # the digit runs into ints gives numeric ordering without
        # distutils.version.LooseVersion (distutils is removed in Python 3.12).
        parts = re.split(r'(\d+)', name)
        return [int(part) if pos % 2 else part for pos, part in enumerate(parts)]

    return sorted((f.name for f in fields), key=natural_key)
# Core Identifiers
field_names = [
    "eid",     # Participant ID
    "p31",     # Sex
    "p21022",  # Age at recruitment
    "p34",     # Year of birth
]

# ICD Diagnoses
# field_names_for_id() prepends the "p" itself (its pattern starts with
# '^p{}'), so it must be given the bare showcase ID. Passing "p41270" makes it
# search for names starting with "pp41270", which match nothing, and the ICD
# columns are then silently left out of field_names.
field_names.extend(field_names_for_id("41270"))  # ICD10 diagnosis (with instances)
field_names.extend(field_names_for_id("41271"))  # ICD9 diagnosis (with instances)

# Add all instances of ICD diagnosis (Date of first in-patient diagnosis)
field_names.extend(field_names_for_id("41280"))  # ICD10 date
field_names.extend(field_names_for_id("41281"))  # ICD9 date
# Diet-related variables
# Food-frequency questions requested for instance 0 only
field_names.extend([
    "p1329_i0",  # FFQ Oily fish intake
    "p1339_i0",  # FFQ Non-oily fish intake
])

# Diet fields requested for instances 0-4 (names are p<ID>_i0 ... p<ID>_i4),
# listed in the order they are appended
repeated_diet_ids = (
    103140,  # Fish consumer
    103150,  # Tinned tuna intake
    103160,  # Oily fish intake
    103170,  # Breaded fish intake
    103180,  # Battered fish intake
    103190,  # White fish intake
    103200,  # Prawns intake
    103210,  # Lobster/crab intake
    103220,  # Shellfish intake
    103230,  # Other fish intake
    105010,  # When diet questionnaire completed
    26002,   # Energy intake
    100020,  # Typical diet yesterday
    100010,  # Portion size
)
field_names.extend(
    f"p{diet_id}_i{instance}"
    for diet_id in repeated_diet_ids
    for instance in range(5)
)

# PHQ-9, GAD-7, and 4 baseline mental health questions
field_names += [
    # PHQ-9
    "p20514", "p20510", "p20517", "p20519", "p20511",
    "p20507", "p20508", "p20518", "p20513",
    # GAD-7
    "p20506", "p20509", "p20520", "p20515",
    "p20516", "p20505", "p20512",
    # 4 baseline mental health questions
    "p2050_i0", "p2060_i0", "p2080_i0", "p2070_i0",
]

# Additional covariates
field_names += [
    "p22189",     # Townsend deprivation index
    "p20116_i0",  # Smoking status
    "p21002_i0",  # Weight
    "p50_i0",     # Standing height
    "p22032_i0",  # IPAQ activity group
    # Pregnancy status
    "p3140_i0", "p3140_i1", "p3140_i2", "p3140_i3",
    "p21001_i0",  # BMI
    "p53_i0",     # Date of attending assessment centre
]
## Step 4: Load Predefined Cohort ##
# Load predefined MS cohort from the "Fish_cohorts" folder
# (this cohort object is what retrieve_fields is called on in Step 5)
# NOTE(review): DNAnexus examples appear to retrieve from the participant
# entity and pass the cohort's SQL as a filter (filter_sql=cohort.sql) rather
# than calling retrieve_fields on the cohort itself - TODO confirm against the
# dxdata documentation.
cohort = dxdata.load_cohort("/Fish_project/Fish_cohorts/MS_combined_cohort_primary_analysis_17-03-25")
## Step 5: Retrieving fields using batch processing ##
# Import pandas first
# NOTE(review): the name `pd` is not referenced anywhere in the code shown here.
import pandas as pd
# Create a connection
# This single connection object is passed as `engine` to every retrieve_fields call.
# NOTE(review): the only log output reported for this step comes from
# thrift/transport/TSSLSocket.py, which suggests the default connection talks
# to a Thrift endpoint. dxdata.connect() is believed to accept
# dialect="hive+pyspark" to run queries on the notebook's own Spark session
# instead - TODO confirm before changing.
connection = dxdata.connect()
# Optimised batch retrieval function using Spark DataFrames
def retrieve_batch_fields(batch_idx, field_batch):
    """
    Retrieve one batch of fields for the cohort as a Spark DataFrame.

    Args:
        batch_idx: Batch number; used only in the progress messages.
        field_batch: Iterable of field names to retrieve. It is not modified;
            when "eid" is missing, a new list with "eid" in front is used.

    Returns:
        The Spark DataFrame returned by cohort.retrieve_fields, or None if
        the retrieval or the row count raised an exception (the error is
        printed, not re-raised, so the caller can record the failure).
    """
    print(f"Starting batch {batch_idx} with {len(field_batch)} fields...")

    # Batches are joined on 'eid', so every batch must carry it - including
    # batch 0, which the previous `batch_idx > 0` guard skipped.
    if "eid" not in field_batch:
        field_batch = ["eid", *field_batch]

    # Print fields being retrieved
    print(f"Fields in this batch: {field_batch}")

    # Retrieve fields with cohort filtering
    try:
        df_batch = cohort.retrieve_fields(names=field_batch, engine=connection)
        # count() forces execution so problems surface here, per batch,
        # rather than later during the join or the final conversion.
        row_count = df_batch.count()
        print(f"Retrieved {row_count} rows for batch {batch_idx}")
        return df_batch  # Return Spark DataFrame directly
    except Exception as e:
        # Deliberate best-effort: report the failure and signal it with None.
        print(f"ERROR in batch {batch_idx}: {str(e)}")
        return None
# Number of fields requested per retrieve_batch_fields call
batch_size = 10

# Joined result so far; stays None until the first batch succeeds
final_spark_df = None

# Ceiling division: how many batches are needed to cover every field
total_batches = -(-len(field_names) // batch_size)
print(f"Processing {len(field_names)} fields in {total_batches} batches...")

# Outcome bookkeeping for the loop below
successful_batches = set()
failed_batches = []

for batch_num in range(total_batches):
    start_idx = batch_num * batch_size
    end_idx = min(start_idx + batch_size, len(field_names))

    # Batches already recorded as done are not fetched again
    if batch_num in successful_batches:
        print(f"Skipping batch {batch_num} (already processed)")
        continue

    # Fields belonging to this batch
    current_fields = field_names[start_idx:end_idx]

    print(f"Processing batch {batch_num+1}/{total_batches} ({start_idx}-{end_idx})")
    batch_df = retrieve_batch_fields(batch_num, current_fields)

    # None means the retrieval failed. NOTE: the code shown here only records
    # and reports failed batches; nothing retries them despite the message.
    if batch_df is None:
        failed_batches.append(batch_num)
        print(f"Adding batch {batch_num} to failed batches list. Will retry later.")
        continue

    if final_spark_df is not None:
        # Join in Spark (not pandas); the outer join keeps every eid
        print("Joining with existing data...")
        final_spark_df = final_spark_df.join(batch_df, on="eid", how="outer")
        print(f"Joined dataset now has {len(final_spark_df.columns)} columns")
    else:
        # First successful batch seeds the result
        final_spark_df = batch_df
        print("First batch becomes base dataset")

    successful_batches.add(batch_num)

    batches_done = batch_num + 1
    at_fifth_batch = batches_done % 5 == 0

    # Every fifth batch: cache and materialise the joined frame
    if at_fifth_batch:
        print("Caching intermediate results...")
        final_spark_df.cache().count()

    # Every fifth batch and after the last batch: write a parquet checkpoint
    if at_fifth_batch or batches_done == total_batches:
        checkpoint_path = f'/tmp/checkpoint_batch_{batch_num+1}'
        print(f"Saving checkpoint to {checkpoint_path}")
        final_spark_df.write.mode("overwrite").parquet(checkpoint_path)
# Report on any failed batches
if failed_batches:
    print(f"WARNING: {len(failed_batches)} batches failed: {failed_batches}")
else:
    print("All batches processed successfully!")

# final_spark_df is still None when no batch succeeded (or there were no
# fields at all). Stop with a clear message instead of the AttributeError that
# None.toPandas() would raise.
if final_spark_df is None:
    raise RuntimeError(
        "No batch was retrieved successfully; there is no data to convert."
    )

# Convert final Spark DataFrame to Pandas at the end
print("Converting final dataset to Pandas...")
final_df = final_spark_df.toPandas()

# Final dataset stats
print(f"Final dataset: {final_df.shape[0]} rows × {final_df.shape[1]} columns")
## Step 6: Save and Upload Data ##
# Save as CSV with today’s date
# (the date is hard-coded in the file name, not computed; index=False keeps
# the pandas row index out of the file)
final_df.to_csv('MS_combined_cohort_primary_analysis_with_data_fields_19-03-25.csv', index=False)
# Upload CSV to the correct folder in the project with today’s date
# NOTE(review): %%bash is a Jupyter cell magic, so each %%bash line and the
# `dx upload` command after it presumably form their own notebook cell -
# confirm the cell boundaries against the notebook.
%%bash
dx upload MS_combined_cohort_primary_analysis_with_data_fields_19-03-25.csv --path /Fish_project/Fish_cohorts/
# Upload notebook
%%bash
dx upload Spark_JupyterLab_Creating_Primary_Dataset_19-03-25.ipynb --path /Fish_project/Spark_notebooks/
Steps 1–4 executed successfully. However, Step 5 (retrieving data fields) keeps stalling, even when adjusting the approach:
- Initial attempt: Retrieving all fields at once:
# Create a connection
connection = dxdata.connect()
# Request every name in field_names for the cohort in one call (no batching)
df = cohort.retrieve_fields(names=field_names, engine=connection)
# Convert the whole result to pandas in one step
df_pandas = df.toPandas()
# Report the number of rows returned
print("Cohort size:", df_pandas.shape[0])
- Modified attempts: Batches of 10 (as above), 5, and 1 field at a time (adapted the above for each).
The log for Step 5 shows no clear error message, only this warning:
/opt/conda/lib/python3.11/site-packages/thrift/transport/TSSLSocket.py:53: DeprecationWarning: ssl.PROTOCOL_TLS is deprecated
  self._context = ssl.SSLContext(ssl_version)
Despite reducing the batch size, the issue persists: CPU usage drops to 0% and even the first variable (eid) is never retrieved.
Has anyone encountered similar issues with retrieving multiple fields? Could this be related to cluster configuration, memory limits, or the way the cohort is loaded? Any suggestions would be greatly appreciated!
Thank you in advance!
Comments
0 comments
Please sign in to leave a comment.