# Copyright 2023-2024 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Author: Keegan Smith
from __future__ import annotations
import csv
import logging
import os
import re
import time
from concurrent.futures import as_completed, ProcessPoolExecutor, ThreadPoolExecutor
import datetime
import json
import shutil
from typing import Dict, Optional, Tuple, Union
import pendulum
import xmltodict
from airflow import AirflowException
from airflow.exceptions import AirflowSkipException
from airflow.hooks.base import BaseHook
from airflow.models import DagRun
from google import auth
from google.auth.compute_engine import Credentials as ComputeEngineCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from google.cloud import bigquery
from google.cloud.bigquery import SourceFormat
from academic_observatory_workflows.s5cmd import S5Cmd
from academic_observatory_workflows.orcid_telescope.release import OrcidRelease
from academic_observatory_workflows.orcid_telescope.batch import OrcidBatch
from observatory_platform.airflow.airflow import is_first_dag_run
from observatory_platform.airflow.workflow import CloudWorkspace, cleanup
from observatory_platform.dataset_api import DatasetRelease, DatasetAPI
from observatory_platform.date_utils import datetime_normalise
from observatory_platform.files import change_keys, save_jsonl_gz
from observatory_platform.google import bigquery as bq
from observatory_platform.google.gke import gke_service_account_email
from observatory_platform.google.gcs import (
gcs_blob_uri,
gcs_create_aws_transfer,
gcs_hmac_key,
gcs_list_blobs,
gcs_upload_files,
)
# Name of the AWS bucket that the ORCID summaries are transferred from
ORCID_AWS_SUMMARIES_BUCKET = "v2.0-summaries"
# Matches four hyphen-separated groups of four characters: all digits, except
# that the final character may be a digit or "X"
ORCID_REGEX = r"\d{4}-\d{4}-\d{4}-\d{3}(\d|X)\b"
def fetch_release(
    dag_id: str,
    run_id: str,
    dag_run: DagRun,
    data_interval_start: pendulum.DateTime,
    data_interval_end: pendulum.DateTime,
    cloud_workspace: CloudWorkspace,
    api_bq_dataset_id: str,
    bq_dataset_id: str,
    bq_main_table_name: str,
    bq_upsert_table_name: str,
    bq_delete_table_name: str,
) -> dict:
    """Build the OrcidRelease for this DAG run and return it in dictionary form.

    On the first DAG run no dataset releases may exist yet and both cutoff dates are set to the
    minimum datetime. On later runs the cutoff dates are taken from the latest stored dataset release.

    :param dag_id: The ID of the dag
    :param run_id: The ID of this dag run
    :param dag_run: This dag run's DagRun object
    :param data_interval_start: The start of the data interval
    :param data_interval_end: The end of the data interval
    :param cloud_workspace: The CloudWorkspace object
    :param api_bq_dataset_id: The bigquery dataset ID for the API
    :param bq_dataset_id: The bigquery dataset ID for the telescope
    :param bq_main_table_name: The name of the main bigquery table
    :param bq_upsert_table_name: The name of the upsert table
    :param bq_delete_table_name: The name of the delete table
    :return: The OrcidRelease converted to a dictionary with to_dict()
    :raises ValueError: If the number of stored dataset releases does not fit the run type
    """
    dataset_api = DatasetAPI(bq_project_id=cloud_workspace.project_id, bq_dataset_id=api_bq_dataset_id)
    stored_releases = dataset_api.get_dataset_releases(
        dag_id=dag_id, entity_id="orcid", date_key="changefile_end_date"
    )
    first_run = is_first_dag_run(dag_run)

    # Work out the cutoff dates that the new release is measured against
    if first_run:
        if len(stored_releases) != 0:
            raise ValueError(
                "fetch_releases: there should be no DatasetReleases stored in the Observatory API on the first DAG run."
            )
        latest_modified = pendulum.instance(datetime.datetime.min)
        previous_end = pendulum.instance(datetime.datetime.min)
    else:
        if len(stored_releases) < 1:
            raise ValueError(
                "fetch_releases: there should be at least 1 DatasetRelease in the Observatory API after the first DAG run"
            )
        latest_release = dataset_api.get_latest_dataset_release(
            dag_id=dag_id, entity_id="orcid", date_key="changefile_end_date"
        )
        latest_modified = pendulum.parse(latest_release.extra["latest_modified_record_date"])
        logging.info(f"Proceeding with previous latest modified record date: {latest_modified}")
        previous_end = latest_release.changefile_end_date

    orcid_release = OrcidRelease(
        dag_id=dag_id,
        run_id=run_id,
        cloud_workspace=cloud_workspace,
        bq_dataset_id=bq_dataset_id,
        bq_main_table_name=bq_main_table_name,
        bq_upsert_table_name=bq_upsert_table_name,
        bq_delete_table_name=bq_delete_table_name,
        prev_release_end=previous_end,
        start_date=data_interval_start,
        end_date=data_interval_end,
        prev_latest_modified_record=latest_modified,
        is_first_run=first_run,
    )
    return orcid_release.to_dict()
def create_dataset(release: dict, dataset_description: str) -> None:
    """Create the BigQuery dataset for the ORCID release.

    :param release: The Orcid release, in dictionary form
    :param dataset_description: The description to give the Orcid dataset
    """
    orcid_release = OrcidRelease.from_dict(release)
    workspace = orcid_release.cloud_workspace
    bq.bq_create_dataset(
        project_id=workspace.project_id,
        dataset_id=orcid_release.bq_dataset_id,
        location=workspace.data_location,
        description=dataset_description,
    )
def transfer_orcid(
    release: dict, aws_orcid_conn_id: str, transfer_attempts: int, orcid_bucket: str, orcid_summaries_prefix: str
):
    """Transfer the ORCID summaries from the AWS bucket into the Google Cloud bucket.

    The transfer is retried until it succeeds or the number of attempts is exhausted.

    :param release: The Orcid release, in dictionary form
    :param aws_orcid_conn_id: The airflow connection ID for the AWS ORCID bucket
    :param transfer_attempts: The number of times to attempt the transfer job before giving up
    :param orcid_bucket: The name of the gcs bucket to store the Orcid data
    :param orcid_summaries_prefix: The prefix in which to store the summaries in the bucket
    :raises AirflowException: If no attempt succeeded
    """
    orcid_release = OrcidRelease.from_dict(release)
    aws_key = aws_orcid_key(aws_orcid_conn_id)

    for attempt in range(1, transfer_attempts + 1):
        logging.info(f"Beginning AWS to GCP transfer attempt no. {attempt}")
        succeeded, n_objects = gcs_create_aws_transfer(
            aws_key=aws_key,
            aws_bucket=ORCID_AWS_SUMMARIES_BUCKET,
            include_prefixes=[],
            gc_project_id=orcid_release.cloud_workspace.project_id,
            gc_bucket_dst_uri=gcs_blob_uri(orcid_bucket, f"{orcid_summaries_prefix}/"),
            description="Transfer ORCID data from AWS to GCP",
        )
        logging.info(f"Attempt {attempt}: Total number of objects transferred: {n_objects}")
        logging.info(f"Attempt {attempt}: Success? {succeeded}")
        if succeeded:
            break
    else:
        # Reached when every attempt failed, or when no attempt was made at all
        raise AirflowException("Error transferring data. See GCP logs for details.")
def bq_create_main_table_snapshot(release: dict, snapshot_expiry_days: int):
    """Snapshot the main table so that it can be rolled back if something goes wrong.

    No snapshot is taken on the first run. The snapshot expires snapshot_expiry_days after creation.

    :param release: The Orcid release, in dictionary form
    :param snapshot_expiry_days: The number of days to keep the snapshot
    :raises AirflowSkipException: On the first run
    :raises AirflowException: If the snapshot could not be created
    """
    orcid_release = OrcidRelease.from_dict(release)
    if orcid_release.is_first_run:
        skip_msg = "bq_create_main_table_snapshots: skipping as snapshots are not created on the first run"
        logging.info(skip_msg)
        raise AirflowSkipException(skip_msg)

    expires_at = pendulum.now().add(days=snapshot_expiry_days)
    snapshot_created = bq.bq_snapshot(
        src_table_id=orcid_release.bq_main_table_id,
        dst_table_id=orcid_release.bq_snapshot_table_id,
        expiry_date=expires_at,
    )
    if snapshot_created:
        return
    raise AirflowException("Error occurred on table creation.")
def create_manifests(
    release: dict, orcid_bucket: str, orcid_summaries_prefix: str, max_workers: Optional[int] = None
) -> None:
    """Create a manifest of all the modified files in the orcid bucket.

    A manifest is created for every ORCID batch using a thread pool. The batch manifests are then
    concatenated, in batch order, into the release's master manifest file.

    :param release: The Orcid release, in dictionary form
    :param orcid_bucket: The name of the gcs bucket to store the Orcid data
    :param orcid_summaries_prefix: The prefix in which to store the summaries in the bucket
    :param max_workers: Number of threads to use. If not set, uses twice the cpu count
    """
    release = OrcidRelease.from_dict(release)
    logging.info("Creating manifest")
    if not max_workers:
        # os.cpu_count() returns None when the number of CPUs cannot be determined
        max_workers = (os.cpu_count() or 1) * 2
    logging.info(f"Using {max_workers} threads for manifest creation")

    orcid_batches = release.orcid_batches()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                create_orcid_batch_manifest,
                orcid_batch=orcid_batch,
                reference_date=release.prev_latest_modified_record,
                bucket=orcid_bucket,
                bucket_prefix=orcid_summaries_prefix,
            )
            for orcid_batch in orcid_batches
        ]
        for future in as_completed(futures):
            # result() re-raises any exception that occurred in the worker thread
            future.result()

    # Open and write each batch manifest to the master manifest file
    # NOTE(review): MANIFEST_HEADER is not defined in the visible part of this module - confirm it is in scope
    logging.info("Joining manifest files")
    # newline="" lets the csv module control line endings, matching how the batch manifests are written
    with open(release.master_manifest_file, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=MANIFEST_HEADER)
        writer.writeheader()
        for orcid_batch in orcid_batches:
            with open(orcid_batch.manifest_file, "r", newline="") as df:
                writer.writerows(csv.DictReader(df))
def download(release: dict):
    """Download the files listed for each ORCID batch from the gcs bucket using s5cmd.

    Batches that have no missing records are skipped. The number of files accounted for is checked
    against the records found on disk once all batches have been processed.

    :param release: The Orcid release, in dictionary form
    :raises AirflowException: If the credentials type is unknown or the service account email is unavailable
    :raises RuntimeError: If s5cmd returns a non-zero exit code for a batch
    :raises FileNotFoundError: If records are still missing after a download, or the final count does not match
    """
    orcid_release = OrcidRelease.from_dict(release)

    # Work out which service account email to create the HMAC key for
    credentials, _ = auth.default()
    if not isinstance(credentials, (ServiceAccountCredentials, ComputeEngineCredentials)):
        raise AirflowException(f"Unknown credentials type: {type(credentials)}")
    if isinstance(credentials, ServiceAccountCredentials):
        email = credentials.service_account_email
        logging.info("Using service account credentials")
    else:
        logging.info("Using compute engine credentials")
        email = gke_service_account_email()
        if not email:
            raise AirflowException("Email could not be retrieved")

    with gcs_hmac_key(orcid_release.cloud_workspace.project_id, email) as (key, secret):
        file_count = 0
        started = time.time()
        for batch in orcid_release.orcid_batches():
            if not batch.missing_records:
                file_count += len(batch.existing_records)
                logging.info(f"All files present for {batch.batch_str}. Skipping download.")
                continue

            logging.info(f"Downloading files for ORCID directory: {batch.batch_str}")
            with open(batch.download_log_file, "w") as log_stream:
                downloader = S5Cmd(access_credentials=(key.access_id, secret), out_stream=log_stream)
                result = downloader.download_from_bucket(uris=batch.blob_uris, local_path=batch.download_batch_dir)
                # The exit code is the last element of the returned sequence
                returncode = result[-1]

            # Check for errors
            if returncode != 0:
                with open(batch.download_log_file) as log_stream:
                    logging.error(log_stream.read())
                raise RuntimeError(
                    f"Download attempt '{batch.batch_str}': returned non-zero exit code: {returncode}. See log file: {batch.download_log_file}"
                )
            if batch.missing_records:
                raise FileNotFoundError(f"All files were not downloaded for {batch.batch_str}. Aborting.")

            # Keep track of files downloaded
            n_expected = len(batch.expected_records)
            file_count += n_expected
            logging.info(f"Downloaded {n_expected} records for batch '{batch.batch_str}'")
        elapsed = time.time() - started

        # Check all files exist
        n_on_disk = len(orcid_release.downloaded_records)
        if file_count != n_on_disk:
            raise FileNotFoundError(f"Downloaded {file_count} files but found {n_on_disk} records on disk.")
        logging.info(f"Completed download for {file_count} files in {str(datetime.timedelta(seconds=elapsed))}")
def clean_downloads(release: dict):
    """Delete the release's download folder and everything inside it.

    :param release: The Orcid release, in dictionary form
    """
    download_folder = OrcidRelease.from_dict(release).download_folder
    logging.info(f"Removing folder: {download_folder}")
    shutil.rmtree(download_folder)
def bq_load_main_table(release: dict, schema_file_path: str):
    """Load the main table from the upsert blobs in the transform bucket. Only done on the first run.

    :param release: The Orcid release, in dictionary form
    :param schema_file_path: The path to the schema file to use for table load
    :raises AirflowSkipException: If this is not the first run
    :raises AirflowException: If the table load failed
    """
    orcid_release = OrcidRelease.from_dict(release)
    if not orcid_release.is_first_run:
        skip_msg = "bq_load_main_table: skipping as the main table is only created on the first run"
        logging.info(skip_msg)
        raise AirflowSkipException(skip_msg)

    transform_blobs = gcs_list_blobs(
        orcid_release.cloud_workspace.transform_bucket, match_glob=orcid_release.upsert_blob_glob
    )
    uris = [gcs_blob_uri(blob.bucket.name, blob.name) for blob in transform_blobs]
    loaded = bq.bq_load_table(
        uri=uris,
        table_id=orcid_release.bq_main_table_id,
        schema_file_path=schema_file_path,
        source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
        write_disposition=bigquery.WriteDisposition.WRITE_EMPTY,
        ignore_unknown_values=False,
    )
    if loaded:
        return
    raise AirflowException("Error occurred loading data to main table")
def bq_load_upsert_table(release: dict, schema_file_path: str):
    """Load the upsert table into bigquery from the upsert blobs. Skipped on the first run.

    :param release: The Orcid release, in dictionary form
    :param schema_file_path: The path to the schema file to use for table load
    :raises AirflowSkipException: On the first run
    :raises AirflowException: If the table load failed
    """
    orcid_release = OrcidRelease.from_dict(release)
    if orcid_release.is_first_run:
        skip_msg = "bq_load_upsert_table: skipping as no records are upserted on the first run"
        logging.info(skip_msg)
        raise AirflowSkipException(skip_msg)

    upsert_blobs = gcs_list_blobs(
        orcid_release.cloud_workspace.transform_bucket, match_glob=orcid_release.upsert_blob_glob
    )
    uris = [gcs_blob_uri(blob.bucket.name, blob.name) for blob in upsert_blobs]
    loaded = bq.bq_load_table(
        uri=uris,
        table_id=orcid_release.bq_upsert_table_id,
        schema_file_path=schema_file_path,
        source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        ignore_unknown_values=True,
    )
    if loaded:
        return
    raise AirflowException("Error occurred loading upsert table")
def bq_load_delete_table(release: dict, delete_schema_file_path: str):
    """Load the delete table into bigquery from the delete blobs. Skipped on the first run.

    :param release: The Orcid release, in dictionary form
    :param delete_schema_file_path: The path to the 'delete' schema file to use for table load
    :raises AirflowSkipException: On the first run
    :raises AirflowException: If the table load failed
    """
    orcid_release = OrcidRelease.from_dict(release)
    if orcid_release.is_first_run:
        skip_msg = "bq_load_delete_table: skipping as no records are deleted on the first run"
        logging.info(skip_msg)
        raise AirflowSkipException(skip_msg)

    delete_blobs = gcs_list_blobs(
        orcid_release.cloud_workspace.transform_bucket, match_glob=orcid_release.delete_blob_glob
    )
    uris = [gcs_blob_uri(blob.bucket.name, blob.name) for blob in delete_blobs]
    loaded = bq.bq_load_table(
        uri=uris,
        table_id=orcid_release.bq_delete_table_id,
        schema_file_path=delete_schema_file_path,
        source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        ignore_unknown_values=True,
    )
    if loaded:
        return
    raise AirflowException("Error occurred loading delete table")
def bq_upsert_records(release: dict):
    """Upsert the records from the upsert table into the main table. Skipped on the first run.

    :param release: The Orcid release, in dictionary form
    :raises AirflowSkipException: On the first run
    """
    orcid_release = OrcidRelease.from_dict(release)
    if orcid_release.is_first_run:
        skip_msg = "bq_upsert_records: skipping as no records are upserted on the first run"
        logging.info(skip_msg)
        raise AirflowSkipException(skip_msg)

    table_ids = {
        "main_table_id": orcid_release.bq_main_table_id,
        "upsert_table_id": orcid_release.bq_upsert_table_id,
    }
    bq.bq_upsert_records(primary_key="orcid_identifier", **table_ids)
def bq_delete_records(release: dict):
    """Delete the records listed in the delete table from the main table. Skipped on the first run.

    :param release: The Orcid release, in dictionary form
    :raises AirflowSkipException: On the first run
    """
    orcid_release = OrcidRelease.from_dict(release)
    if orcid_release.is_first_run:
        skip_msg = "bq_delete_records: skipping as no records are deleted on the first run"
        logging.info(skip_msg)
        raise AirflowSkipException(skip_msg)

    table_ids = {
        "main_table_id": orcid_release.bq_main_table_id,
        "delete_table_id": orcid_release.bq_delete_table_id,
    }
    bq.bq_delete_records(
        main_table_primary_key="orcid_identifier.path",
        delete_table_primary_key="id",
        **table_ids,
    )
def add_dataset_release(release: dict, *, api_bq_dataset_id: str, latest_modified_date: str) -> None:
    """Adds release information to API.

    :param release: The Orcid release, in dictionary form
    :param api_bq_dataset_id: bigquery dataset ID of the API
    :param latest_modified_date: The modification datetime of the last modified record of this release
    :raises AirflowException: If latest_modified_date cannot be parsed as a date
    """
    release = OrcidRelease.from_dict(release)

    # Validate the date before storing it, since it is parsed again when the next release is fetched
    try:
        pendulum.parse(latest_modified_date)
    except pendulum.parsing.exceptions.ParserError as err:
        raise AirflowException(f"Latest modified record date not valid: {latest_modified_date}") from err

    api = DatasetAPI(bq_project_id=release.cloud_workspace.project_id, bq_dataset_id=api_bq_dataset_id)
    now = pendulum.now()
    dataset_release = DatasetRelease(
        dag_id=release.dag_id,
        entity_id="orcid",
        dag_run_id=release.run_id,
        created=now,
        modified=now,
        changefile_start_date=release.start_date,
        changefile_end_date=release.end_date,
        extra={"latest_modified_record_date": latest_modified_date},
    )
    api.add_dataset_release(dataset_release)
def cleanup_workflow(release: dict) -> None:
    """Run the workflow cleanup for this release's DAG ID and workflow folder.

    :param release: The Orcid release, in dictionary form
    """
    orcid_release = OrcidRelease.from_dict(release)
    target_dag_id = orcid_release.dag_id
    target_folder = orcid_release.workflow_folder
    cleanup(dag_id=target_dag_id, workflow_folder=target_folder)
def aws_orcid_key(conn_id: str) -> Tuple[str, str]:
    """Look up an Airflow connection and return its login and password.

    :param conn_id: The ID of the Airflow connection
    :return: The login and the password of the connection
    """
    conn = BaseHook.get_connection(conn_id)
    login, password = conn.login, conn.password
    return login, password
def create_orcid_batch_manifest(
    orcid_batch: OrcidBatch, reference_date: pendulum.DateTime, bucket: str, bucket_prefix: str = None
) -> None:
    """Write the manifest (.csv) for one orcid batch.

    The manifest holds the bucket name, blob name and update time of every blob in the batch's
    directory of the bucket that was updated after the reference date. It is written to the
    batch's manifest file.

    :param orcid_batch: The OrcidBatch instance for this orcid directory
    :param reference_date: Only blobs updated after this date are included
    :param bucket: The name of the bucket
    :param bucket_prefix: The prefix to use when listing blobs in the bucket. i.e. where the orcid directories are located
    """
    if bucket_prefix:
        prefix = f"{bucket_prefix}/{orcid_batch.batch_str}/"
    else:
        prefix = f"{orcid_batch.batch_str}/"

    logging.info(f"Creating manifests for {orcid_batch.batch_str}")
    rows = [
        {
            MANIFEST_HEADER[0]: blob.bucket.name,
            MANIFEST_HEADER[1]: blob.name,
            MANIFEST_HEADER[2]: blob.updated,
        }
        for blob in gcs_list_blobs(bucket, prefix=prefix)
        if pendulum.instance(blob.updated) > reference_date
    ]

    with open(orcid_batch.manifest_file, "w", newline="") as manifest_csv:
        manifest_writer = csv.DictWriter(manifest_csv, fieldnames=MANIFEST_HEADER)
        manifest_writer.writeheader()
        manifest_writer.writerows(rows)
    logging.info(f"Manifest saved to {orcid_batch.manifest_file}")
def latest_modified_record_date(release: dict) -> str:
    """Reads the master manifest file and finds the most recent date of modification for the records

    :param release: The Orcid release, in dictionary form
    :return: the most recent date of modification for the records, as returned by datetime_normalise
    :raises AirflowException: If the master manifest contains no records
    """
    release = OrcidRelease.from_dict(release)
    with open(release.master_manifest_file, "r", newline="") as f:
        modified_dates = [pendulum.parse(row["updated"]) for row in csv.DictReader(f)]

    # An empty manifest has no latest date, so fail with a clear message
    if not modified_dates:
        raise AirflowException(f"No records found in manifest file: {release.master_manifest_file}")

    latest_date = datetime_normalise(max(modified_dates))
    logging.info(f"Latest modified record date: {latest_date}")
    return latest_date