# Source code for academic_observatory_workflows.openalex_telescope.generate_schema

# Copyright 2022-2024 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import argparse
import json
import logging
import os
from collections import OrderedDict
from concurrent.futures import as_completed, ProcessPoolExecutor
from pathlib import Path

from academic_observatory_workflows.openalex_telescope.tasks import (
    flatten_schema,
    merge_schema_maps,
    transform_file,
)
from observatory_platform.files import list_files


# Order in which the keys of every schema field are written out.
_SCHEMA_KEY_ORDER = ("name", "type", "mode", "description", "fields")


def _sort_schema_fields(schema):
    """Return a sorted copy of a list of schema field dicts.

    Fields are ordered by their "name" value and the keys of each field are
    emitted in ``_SCHEMA_KEY_ORDER``. Keys that are not listed in
    ``_SCHEMA_KEY_ORDER`` are not carried over to the result. Fields whose
    type is "RECORD" and that have a "fields" entry are sorted recursively.

    :param schema: list of schema field dicts, each with at least a "name" key.
    :return: a new list of new dicts; the input list is not modified.
    """

    sorted_fields = []
    for field in sorted(schema, key=lambda entry: entry["name"]):
        ordered = {key: field[key] for key in _SCHEMA_KEY_ORDER if key in field}
        if ordered.get("type") == "RECORD" and "fields" in ordered:
            ordered["fields"] = _sort_schema_fields(ordered["fields"])
        sorted_fields.append(ordered)
    return sorted_fields


def sort_schema(input_file: Path):
    """Sort the JSON schema stored in a file and write it back to the same file.

    The file is read and written as UTF-8 so the result does not depend on the
    platform's default encoding. The schema is rewritten in place with an
    indent of 2.

    :param input_file: path to a JSON file containing a list of schema fields.
    :return: None.
    """

    # Load the JSON schema from the file
    with open(input_file, mode="r", encoding="utf-8") as f:
        data = json.load(f)

    # Sort the schema entries and the keys of each entry
    sorted_json_schema = _sort_schema_fields(data)

    # Overwrite the file with the sorted schema
    with open(input_file, mode="w", encoding="utf-8") as f:
        json.dump(sorted_json_schema, f, indent=2)
def generate_schema(entity_name: str, input_folder: Path, output_folder: Path, max_workers: int):
    """Build a merged schema for an entity from its part files and save it as JSON.

    Every file under ``input_folder`` whose name matches ``part_NNN.gz`` is
    submitted to ``transform_file`` in a process pool, together with an output
    path that mirrors the file's location relative to ``input_folder`` under
    ``output_folder``. The schema map returned for each file is merged into a
    single map, which is flattened, written to
    ``<output_folder>/<entity_name>.json`` and finally sorted in place with
    ``sort_schema``.

    :param entity_name: name used for the generated schema file.
    :param input_folder: folder that is searched for part files.
    :param output_folder: folder that receives the transformed files and the schema.
    :param max_workers: maximum number of worker processes.
    :return: None.
    """

    combined_schema_map = OrderedDict()

    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        # One transform job per part file
        pending = [
            pool.submit(
                transform_file,
                part_file,
                str(output_folder / Path(part_file).relative_to(input_folder)),
            )
            for part_file in list_files(str(input_folder), r"^part_\d{3}\.gz$")
        ]

        # Collect results in completion order
        for job in as_completed(pending):
            source_path, file_schema_map, file_schema_error = job.result()
            if file_schema_error:
                logging.info(f"Error generating schema for file {source_path}: {file_schema_error}")

            # Files can differ in the fields they contain, so every file's schema is merged in.
            combined_schema_map = merge_schema_maps(to_add=file_schema_map, old=combined_schema_map)

    # Convert the nested schema map into a regular BigQuery schema.
    flat_schema = flatten_schema(schema_map=combined_schema_map)

    # Write the schema, then sort the written file in place.
    generated_schema_path = os.path.join(output_folder, f"{entity_name}.json")
    with open(generated_schema_path, mode="w") as schema_file:
        json.dump(flat_schema, schema_file, indent=2)
    sort_schema(Path(generated_schema_path))
def check_directory(path) -> Path:
    """Check that the provided path is an existing directory.

    Intended for use as an argparse ``type=`` callable: the raised
    ``argparse.ArgumentTypeError`` carries the message describing the problem.

    :param path: the path to check.
    :return: the path as a ``Path`` object when it is a directory.
    :raises argparse.ArgumentTypeError: if the path does not exist, or exists
        but is not a directory.
    """

    directory = Path(path)
    if directory.is_dir():
        return directory

    # Distinguish "exists but is something else" from "missing" so the message is accurate.
    if directory.exists():
        raise argparse.ArgumentTypeError(f"The path {path} is not a directory.")
    raise argparse.ArgumentTypeError(f"The directory {path} does not exist.")
if __name__ == "__main__":
    # Simple command line tool to generate a BigQuery schema for an OpenAlex entity
    parser = argparse.ArgumentParser(description="Process OpenAlex entities.")

    # Positional arguments; both folders are validated by check_directory during parsing.
    parser.add_argument("entity_name", type=str, help="The name of the entity")
    parser.add_argument("input_folder", type=check_directory, help="The input folder path")
    parser.add_argument("output_folder", type=check_directory, help="The output folder path")

    # Optional worker count; defaults to os.cpu_count() when not given.
    parser.add_argument(
        "--max_workers",
        type=int,
        default=os.cpu_count(),
        help="The maximum number of workers (default: number of CPUs)",
    )

    # Parse the command line and run the schema generation.
    args = parser.parse_args()
    generate_schema(
        entity_name=args.entity_name,
        input_folder=args.input_folder,
        output_folder=args.output_folder,
        max_workers=args.max_workers,
    )