# Source code for disco.cli

import sys
import os
os.environ["QT_QPA_PLATFORM"] = "offscreen"
import re
import warnings
import argparse
import csv
import numpy as np
import torch   
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.wcs import WCS
from astropy.wcs import FITSFixedWarning
from astropy.wcs.utils import pixel_to_skycoord, skycoord_to_pixel
from astropy.coordinates import SkyCoord
from astropy.time import Time
import astropy.units as u
from disco.core.cnn_inference import DiscoNet, predict_with_cnn
from disco.core.optimization import (geometric_loss, auto_tune_geometry_hybrid, estimate_geometry_errors, refine_center_geometry)
from disco.core.fits_utils import (
    get_alma_beam, deconvolve_beams, make_gaussian_kernel_casa,
    find_center_robust, auto_detect_parameters, extract_profile,
    save_debug_deproj_center, measure_rout_deproj, refine_center_local,
    deg_to_sex, pixel_to_icrs, icrs_to_pixel, get_obs_epoch,
    query_gaia_proper_motion, apply_proper_motion_correction,
    _ASTROQUERY_AVAILABLE
)
warnings.filterwarnings("ignore", category=FITSFixedWarning)
from scipy.ndimage import (
    map_coordinates, gaussian_filter, gaussian_filter1d,
    zoom, center_of_mass, binary_fill_holes, label
)
from scipy.optimize import minimize, differential_evolution
from scipy.signal import fftconvolve
from tqdm import tqdm



def discover_groups(base_dir):
    """Walk *base_dir* and bundle FITS files into processing groups.

    Within each directory, files are grouped by the filename prefix that
    precedes a ``Band_<n>`` token (``_Band_3``, ``band7`` ... matched
    case-insensitively for the leading B). Every group records:

    - ``name``: ``"<parent-dir>_<prefix>"`` (just the prefix when the
      parent directory name is empty),
    - ``files``: the member paths, sorted,
    - ``output_dir``: ``<directory>/<prefix>`` where results are written.

    Returns a list of such dicts; directories without FITS files are skipped.
    """
    discovered = []
    for directory, _subdirs, entries in os.walk(base_dir):
        # Case-insensitive extension match so ".FITS" etc. are included.
        candidates = [
            os.path.join(directory, entry)
            for entry in entries
            if entry.lower().endswith('.fits')
        ]
        if not candidates:
            continue

        parent = os.path.basename(os.path.dirname(directory))

        # Bucket files by the stem portion before the band token.
        by_prefix = {}
        for path in candidates:
            stem, _ext = os.path.splitext(os.path.basename(path))
            head = re.split(r'_?[Bb]and_?\d+', stem, maxsplit=1)[0]
            # Fall back to the full stem when the split leaves nothing.
            key = head.rstrip('_') if head else stem
            by_prefix.setdefault(key, []).append(path)

        for key, members in by_prefix.items():
            discovered.append({
                "name": f"{parent}_{key}" if parent else key,
                "files": sorted(members),
                "output_dir": os.path.join(directory, key),
            })
    return discovered
def run_pipeline(files_to_process, group_name, output_dir, args, cnn_model):
    """Run the full DISCO pipeline on one group of FITS images.

    Five phases (tracked on a tqdm progress bar):
      1. Read FITS files, normalize flux units, find centers, auto-detect radii.
      2. Optimize disk geometry (inclination/PA) — CNN-hybrid or Nelder-Mead.
      3. Estimate geometry errors and fuse Rout estimates.
      4. Optionally homogenize beams, then extract radial profiles.
      5. Save the profile plot (and CSV tables when ``args.csv == 'on'``).

    Parameters
    ----------
    files_to_process : list[str]
        Paths of the FITS files belonging to this group.
    group_name : str
        Label used in log lines and output filenames.
    output_dir : str
        Directory where PNG/CSV results are written (created if missing).
    args : argparse.Namespace
        CLI options (rout, rmin, incl, pa, beam, homobeam, csv, debug).
    cnn_model : torch.nn.Module or None
        Optional CNN geometry prior; analytical optimization is used when None.

    Returns ``None``; all results are side effects (files, log output).
    """
    tqdm.write(f"\n{'='*60}")
    tqdm.write(f"[START] Group: {group_name} ({len(files_to_process)} file(s))")
    tqdm.write(f"{'='*60}")
    pbar = tqdm(total=5, desc=f"Pipeline [{group_name}]", leave=False, dynamic_ncols=True)

    # ---------------- Phase 1: read every FITS file ----------------
    pbar.set_postfix_str("Phase 1/5: Reading FITS")
    temp_data = []
    max_auto_rout = 0.0  # largest heuristic outer radius seen across the group
    for filepath in files_to_process:
        filename = os.path.basename(filepath)
        try:
            with fits.open(filepath, memmap=True) as hdul:
                data = np.nan_to_num(np.squeeze(hdul[0].data).astype(np.float32))
                header = hdul[0].header
                # Normalize flux to mJy/beam: explicit Jy/beam is converted;
                # a missing BUNIT with a small peak (< 5) is presumed Jy/beam
                # too — NOTE(review): heuristic, confirm against data sources.
                bunit = str(header.get('BUNIT', '')).strip().upper()
                if bunit == 'JY/BEAM':
                    data *= 1000
                    header['BUNIT'] = 'mJy/beam'
                elif bunit == '' and np.max(data) < 5.0:
                    data *= 1000
                # CDELT2 is deg/pixel; 0.03 fallback when absent. Scale in arcsec.
                pixel_scale = abs(header.get('CDELT2', 0.03)) * 3600
                cx, cy = find_center_robust(data, pixel_scale, header)
                auto_rmin, auto_rout, bmaj = auto_detect_parameters(data, header, pixel_scale, cx, cy)
                bmin = header.get('BMIN', 0) * 3600  # beam minor axis, arcsec
                max_auto_rout = max(max_auto_rout, auto_rout)
                temp_data.append({
                    "filename": filename,
                    "data": data,
                    "header": header,
                    "pixel_scale": pixel_scale,
                    "cx": cx,
                    "cy": cy,
                    "auto_rmin": auto_rmin,
                    "auto_rout": auto_rout,
                    "bmaj": bmaj,
                    "bmin": bmin,
                    "obs_epoch": get_obs_epoch(header)
                })
        except Exception as e:
            tqdm.write(f"[ERROR] Failed to read (unknown): {str(e)}")
    if not temp_data:
        tqdm.write(f"[WARN] No readable FITS files for group '{group_name}'. Skipping.")
        pbar.close()
        return

    # Rank images: SNR (peak / edge RMS) penalized by beam area so the
    # sharpest high-SNR image becomes the geometry reference.
    best_score = -1
    best_img_idx = 0
    for i, item in enumerate(temp_data):
        d = item["data"]
        # Edge RMS from the 10-pixel border frames — used as the noise floor.
        edge = np.concatenate([d[:10, :].ravel(), d[-10:, :].ravel(),
                               d[:, :10].ravel(), d[:, -10:].ravel()])
        rms = np.nanstd(edge)
        snr = np.nanmax(d) / rms if rms > 0 else 0
        item["snr"] = snr
        bmaj_as = item["bmaj"]
        bmin_as = item["bmin"] if item["bmin"] > 0 else bmaj_as
        beam_area = bmaj_as * bmin_as
        # Images with SNR <= 3 are effectively disqualified (score 0).
        score = snr / (beam_area ** 1.5) if beam_area > 0 and snr > 3.0 else 0.0
        if score > best_score:
            best_score = score
            best_img_idx = i
    best_item = temp_data[best_img_idx]

    geom_rout = max_auto_rout
    unified_rout = args.rout if args.rout is not None else max_auto_rout
    final_rmin = args.rmin if args.rmin > 0.0 else best_item["auto_rmin"]

    # SNR-weighted mean sky center across all images; RA is averaged on the
    # circle (atan2 of weighted sin/cos) to survive the 0/360 wrap.
    ref_ra, ref_dec = None, None
    pmra_gaia, pmdec_gaia, gaia_sep = None, None, None
    ra_vals, dec_vals, snr_weights = [], [], []
    for item in temp_data:
        try:
            ra_i, dec_i = pixel_to_icrs(item["header"], item["cx"], item["cy"])
            ra_vals.append(ra_i)
            dec_vals.append(dec_i)
            snr_weights.append(max(item.get("snr", 1.0), 0.01))
        except Exception:
            pass
    if ra_vals:
        w = np.array(snr_weights, dtype=float)
        w /= w.sum()
        ra_rad = np.radians(ra_vals)
        ref_ra = float(np.degrees(np.arctan2(
            np.sum(np.sin(ra_rad) * w), np.sum(np.cos(ra_rad) * w)
        )) % 360.0)
        ref_dec = float(np.sum(np.array(dec_vals) * w))
        pmra_gaia, pmdec_gaia, gaia_sep = query_gaia_proper_motion(ref_ra, ref_dec)
        best_epoch = best_item.get("obs_epoch", None)
        # Propagate the reference center to every other image, applying a
        # Gaia proper-motion correction for the epoch difference when known.
        for item in temp_data:
            if item is best_item:
                continue
            try:
                apply_ra, apply_dec = ref_ra, ref_dec
                if pmra_gaia is not None and best_epoch is not None:
                    item_epoch = item.get("obs_epoch", None)
                    if item_epoch is not None:
                        dt_yr = item_epoch.jyear - best_epoch.jyear
                        apply_ra, apply_dec = apply_proper_motion_correction(
                            ref_ra, ref_dec, pmra_gaia, pmdec_gaia, dt_yr
                        )
                px, py = icrs_to_pixel(item["header"], apply_ra, apply_dec)
                #px, py = refine_center_local(item["data"], item["header"], item["pixel_scale"], px, py)
                item["cx"], item["cy"] = px, py
            except Exception:
                pass  # keep the image's own robust center on any failure
    pbar.update(1)

    # ---------------- Phase 2: geometry (inclination / PA) ----------------
    pbar.set_postfix_str("Phase 2/5: Optimizing Geometry")
    cnn_i, cnn_p = None, None
    if args.incl is not None and args.pa is not None:
        # Both forced on the CLI: skip optimization, zero uncertainties.
        master_incl, master_pa, err_incl, err_pa = args.incl, args.pa, 0.0, 0.0
    else:
        if cnn_model:
            cnn_context_rad = geom_rout * 1.1
            geom_rmin = max(final_rmin, 1.5 * best_item["bmaj"])
            incl, pa, cnn_i, cnn_p, dx, dy = auto_tune_geometry_hybrid(
                best_item["data"], best_item["header"], best_item["pixel_scale"],
                best_item["cx"], best_item["cy"], cnn_model, cnn_context_rad, geom_rmin
            )
            # Apply the (dx, dy) center refinement to the reference image and
            # re-propagate the updated sky center to the other images.
            try:
                new_cx = best_item["cx"] + dx
                new_cy = best_item["cy"] + dy
                ref_ra, ref_dec = pixel_to_icrs(best_item["header"], new_cx, new_cy)
                best_item["cx"], best_item["cy"] = new_cx, new_cy
                best_epoch = best_item.get("obs_epoch", None)
                for item in temp_data:
                    if item is best_item:
                        continue
                    try:
                        apply_ra, apply_dec = ref_ra, ref_dec
                        if pmra_gaia is not None and best_epoch is not None:
                            item_epoch = item.get("obs_epoch", None)
                            if item_epoch is not None:
                                dt_yr = item_epoch.jyear - best_epoch.jyear
                                apply_ra, apply_dec = apply_proper_motion_correction(
                                    ref_ra, ref_dec, pmra_gaia, pmdec_gaia, dt_yr
                                )
                        px, py = icrs_to_pixel(item["header"], apply_ra, apply_dec)
                        item["cx"], item["cy"] = px, py
                    except Exception:
                        # WCS round-trip failed: shift in pixel space instead.
                        item["cx"] += dx
                        item["cy"] += dy
            except Exception:
                # Reference WCS unusable: fall back to a plain pixel shift
                # for every image.
                for item in temp_data:
                    item["cx"] += dx
                    item["cy"] += dy
        else:
            # Analytical path: crop a padded window around the center and
            # minimize the geometric loss with Nelder-Mead.
            pad = 500
            d_pad = np.pad(best_item["data"], pad, mode='constant', constant_values=0)
            crop_rad = int((geom_rout / best_item["pixel_scale"]) * 1.5) + 10
            dc = d_pad[int(best_item["cy"]) + pad - crop_rad:int(best_item["cy"]) + pad + crop_rad,
                       int(best_item["cx"]) + pad - crop_rad:int(best_item["cx"]) + pad + crop_rad]
            res = minimize(
                geometric_loss,
                x0=[30.0, 45.0, 0.0, 0.0],  # initial [incl, PA, dx, dy]
                args=(dc, crop_rad, crop_rad, crop_rad,
                      final_rmin / best_item["pixel_scale"],
                      geom_rout / best_item["pixel_scale"], 150, 1),
                method='Nelder-Mead',
                options={'xatol': 0.05, 'fatol': 1e-5}
            )
            incl, pa = res.x[0], res.x[1] % 180
        master_incl, master_pa = float(incl), float(pa)
    pbar.update(1)

    # ---------------- Phase 3: uncertainties and Rout fusion ----------------
    pbar.set_postfix_str("Phase 3/5: Estimating Errors")
    if args.incl is None or args.pa is None:
        err_incl, err_pa = estimate_geometry_errors(
            best_item["data"], best_item["pixel_scale"], best_item["cx"], best_item["cy"],
            master_incl, master_pa, final_rmin, geom_rout
        )
    if args.rout is None:
        # Fuse the deprojected Rout measurement with the heuristic one.
        # Weights slide toward the heuristic at high inclination where
        # deprojection is less reliable.
        rout_deproj = measure_rout_deproj(
            best_item["data"], best_item["header"], best_item["pixel_scale"],
            best_item["cx"], best_item["cy"], master_incl, master_pa, rmin=final_rmin
        )
        ratio = max_auto_rout / (rout_deproj + 1e-6)
        incl_factor = np.clip((master_incl - 60.0) / 30.0, 0.0, 1.0)
        w_deproj = 0.60 - 0.25 * incl_factor
        w_heur = 1.0 - w_deproj
        if ratio > 1.5 and master_incl < 55.0:
            unified_rout = rout_deproj * 1.15
            fusion_mode = "deproj×1.15"
        elif ratio > 1.5 and master_incl >= 55.0:
            unified_rout = 0.50 * rout_deproj * 1.15 + 0.50 * max_auto_rout
            fusion_mode = "mean (high-incl)"
        elif ratio < 0.8:
            unified_rout = max_auto_rout
            fusion_mode = "heuristic"
        else:
            unified_rout = w_deproj * rout_deproj + w_heur * max_auto_rout
            fusion_mode = "weighted mean"
    # Never go below 1.5 beams or 0.10 arcsec.
    bmaj_ref = best_item["bmaj"]
    unified_rout = max(unified_rout, bmaj_ref * 1.5, 0.10)

    tqdm.write(
        f"[RESULT] Reference : {best_item['filename']}"
        f" | SNR: {best_item['snr']:.1f}"
        f" | Beam: {best_item['bmaj']:.3f}\""
    )
    tqdm.write(
        f" Geometry : i = {master_incl:.1f}° ± {err_incl:.1f}°"
        f" | PA = {master_pa:.1f}° ± {err_pa:.1f}°"
    )
    tqdm.write(
        f" Extent : Rout = {unified_rout:.3f}\" | Rmin = {final_rmin:.3f}\""
        + (f" [{fusion_mode}]" if args.rout is None else " [forced]")
    )
    if cnn_model and cnn_i is not None:
        tqdm.write(f" CNN prior : i = {cnn_i:.1f}° | PA = {cnn_p:.1f}°")
    if ref_ra is not None:
        center_coord = SkyCoord(ra=ref_ra * u.deg, dec=ref_dec * u.deg, frame='icrs')
        tqdm.write(
            f" Center : RA = {center_coord.ra.to_string(unit=u.hour, sep=':', precision=3)}"
            f" | Dec = {center_coord.dec.to_string(sep=':', precision=2)} (ICRS)"
        )
    if pmra_gaia is not None:
        tqdm.write(
            f" Gaia PM : pmRA = {pmra_gaia:.3f} mas/yr"
            f" | pmDec = {pmdec_gaia:.3f} mas/yr"
            f" | match = {gaia_sep:.2f}\""
        )
    elif _ASTROQUERY_AVAILABLE and ref_ra is not None:
        tqdm.write(f" Gaia PM : no match within 3\" — PM correction skipped")
    pbar.update(1)

    # ---------------- Phase 4: beam homogenization + profiles ----------------
    pbar.set_postfix_str("Phase 4/5: Extracting Profiles")
    if args.debug == 'on':
        debug_dir = os.path.join(output_dir, "debug_pipeline")
        os.makedirs(debug_dir, exist_ok=True)
        out_png = os.path.join(debug_dir, f"{group_name}_debug_center_rout.png")
        save_debug_deproj_center(
            best_item["data"], best_item["cx"], best_item["cy"],
            master_incl, master_pa, unified_rout, best_item["pixel_scale"], out_png,
            title=f"{group_name} | i={master_incl:.1f} PA={master_pa:.1f}"
        )
    if args.homobeam == 'on':
        # Target beam: forced via --beam, else 1% above the largest BMAJ in
        # the group (circular beam, BPA 0).
        try:
            if args.beam is not None:
                t_bmaj, t_bmin, t_bpa = args.beam, args.beam, 0.0
            else:
                max_bmaj = max(float(img["header"].get('BMAJ', 0)) for img in temp_data) * 3600.0
                t_bmaj, t_bmin, t_bpa = max_bmaj * 1.01, max_bmaj * 1.01, 0.0
        except Exception:
            t_bmaj, t_bmin, t_bpa = 0.0, 0.0, 0.0
        if t_bmaj > 0:
            for img in temp_data:
                i_bmaj = img["header"].get('BMAJ', 0) * 3600.0
                i_bmin = img["header"].get('BMIN', 0) * 3600.0
                i_bpa = img["header"].get('BPA', 0)
                # Skip images without a beam or already at the target beam.
                if i_bmaj == 0 or i_bmin == 0 or (np.isclose(i_bmaj, t_bmaj) and np.isclose(i_bmin, t_bmin)):
                    continue
                bmaj_c, bmin_c, pa_c = deconvolve_beams(t_bmaj, t_bmin, t_bpa, i_bmaj, i_bmin, i_bpa)
                if bmaj_c is not None:
                    # Convolve with the deconvolved kernel, then rescale so
                    # brightness stays per-beam under the new beam area.
                    kernel = make_gaussian_kernel_casa(bmaj_c, bmin_c, pa_c, img["pixel_scale"])
                    img["data"] = fftconvolve(img["data"], kernel, mode='same')
                    scale_factor = (t_bmaj * t_bmin) / (i_bmaj * i_bmin)
                    img["data"] *= scale_factor
                    img["header"]['BMAJ'] = t_bmaj / 3600.0
                    img["header"]['BMIN'] = t_bmin / 3600.0
                    img["header"]['BPA'] = t_bpa
                    img["bmaj"] = t_bmaj
                    img["bmin"] = t_bmin

    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    ax.set_facecolor('white')
    limit_plot = unified_rout

    def band_snr(item):
        # Edge-RMS SNR, recomputed here because homogenization above may
        # have changed the data since Phase 1.
        d = item["data"]
        edge = np.concatenate([d[:10, :].ravel(), d[-10:, :].ravel(),
                               d[:, :10].ravel(), d[:, -10:].ravel()])
        rms = np.nanstd(edge)
        return np.nanmax(d) / rms if rms > 0 else 0.0

    snr_map = {img["filename"]: band_snr(img) for img in temp_data}
    csv_data = {}
    max_pts = 0
    # Plot highest-SNR bands first so they sit lowest in the z-order legend.
    for img in sorted(temp_data, key=lambda x: snr_map[x["filename"]], reverse=True):
        band_bmaj = img["bmaj"]
        lbl = img["filename"]
        m = re.search(r'(Band_\d+)', lbl, re.IGNORECASE)
        if m:
            lbl = m.group(1).replace('_', ' ')
        r_arcsec, tb_prof, tb_err = extract_profile(
            img["data"], img["header"], master_incl, master_pa,
            img["pixel_scale"], img["cx"], img["cy"], limit_arcsec=limit_plot
        )
        limit_idx = min(np.searchsorted(r_arcsec, limit_plot), len(r_arcsec))
        r_plot = r_arcsec[:limit_idx]
        y_plot = tb_prof[:limit_idx]
        err_plot = tb_err[:limit_idx]
        max_val = np.nanmax(y_plot)
        # Normalize to the profile peak (twice, to absorb numerical residue),
        # then clip into the plotted range.
        y_norm = y_plot / max_val if max_val > 0 else y_plot
        e_norm = err_plot / max_val if max_val > 0 else err_plot
        mx2 = np.nanmax(y_norm)
        if mx2 > 0:
            y_norm = y_norm / mx2
            e_norm = e_norm / mx2
        y_norm_clip = np.clip(y_norm, -0.05, 1.05)
        ax.plot(r_plot, y_norm_clip, lw=2.5, label=lbl)
        lbl_clean = lbl.replace(' ', '_')
        csv_data[lbl_clean] = {
            'filename': img["filename"],
            'snr': snr_map[img["filename"]],
            'cx': img["cx"],
            'cy': img["cy"],
            'bmaj': img["bmaj"],
            'bmin': img.get("bmin", 0.0),
            'max_flux': max_val,
            'r': r_plot,
            'i_raw': y_plot,
            'e_raw': err_plot,
            'i_norm': y_norm_clip,
            'e_norm': e_norm
        }
        max_pts = max(max_pts, len(r_plot))
    pbar.update(1)

    # ---------------- Phase 5: save plot and optional CSV ----------------
    pbar.set_postfix_str("Phase 5/5: Saving Results")
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlim(0, limit_plot)
    ax.set_xlabel("r / arcsec", fontsize=12)
    ax.set_ylabel("Normalized Intensity", fontsize=12)
    ax.tick_params(direction='in', labelsize=10)
    ax.grid(True, which='both', color='gray', alpha=0.3, linestyle='--')
    ax.set_title(
        f"Radial Profiles — {group_name}\n"
        f"i={master_incl:.1f}°±{err_incl:.1f}° PA={master_pa:.1f}°±{err_pa:.1f}°",
        fontweight='bold', fontsize=13
    )
    ax.legend(fontsize=10)
    os.makedirs(output_dir, exist_ok=True)
    output_png = os.path.join(output_dir, f"RP_{group_name}.PNG")
    plt.savefig(output_png, format='png', bbox_inches='tight', facecolor='white')
    plt.close(fig)
    if args.csv == 'on':
        try:
            # Global parameters (one row per quantity).
            ellipse_smaj = unified_rout
            ellipse_smin = unified_rout * np.cos(np.radians(master_incl))
            with open(os.path.join(output_dir, f"RP_{group_name}_global.csv"), mode='w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(["parameter", "value", "uncertainty"])
                writer.writerow(["Rout_arcsec", f"{unified_rout:.4f}", ""])
                writer.writerow(["Rmin_arcsec", f"{final_rmin:.4f}", ""])
                writer.writerow(["Inclination_deg", f"{master_incl:.3f}", f"{err_incl:.3f}"])
                writer.writerow(["PA_deg", f"{master_pa:.3f}", f"{err_pa:.3f}"])
                writer.writerow(["Ellipse_smaj_arcsec", f"{ellipse_smaj:.4f}", ""])
                writer.writerow(["Ellipse_smin_arcsec", f"{ellipse_smin:.4f}", ""])
                if ref_ra is not None:
                    writer.writerow(["Center_RA_deg", f"{ref_ra:.8f}", ""])
                    writer.writerow(["Center_Dec_deg", f"{ref_dec:.8f}", ""])
                if pmra_gaia is not None:
                    writer.writerow(["Gaia_pmRA_masyr", f"{pmra_gaia:.4f}", ""])
                    writer.writerow(["Gaia_pmDec_masyr", f"{pmdec_gaia:.4f}", ""])
                    writer.writerow(["Gaia_match_arcsec", f"{gaia_sep:.4f}", ""])
            # Per-band summary (one row per image).
            with open(os.path.join(output_dir, f"RP_{group_name}_bands.csv"), mode='w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(["filename", "snr", "cx_pix", "cy_pix",
                                 "bmaj_arcsec", "bmin_arcsec", "peak_flux_mJyBeam"])
                for d in csv_data.values():
                    writer.writerow([
                        d['filename'], f"{d['snr']:.1f}", f"{d['cx']:.2f}", f"{d['cy']:.2f}",
                        f"{d['bmaj']:.4f}", f"{d['bmin']:.4f}", f"{d['max_flux']:.6e}"
                    ])
            # Radial profiles, four columns per band; shorter profiles are
            # padded with empty cells so rows stay aligned.
            with open(os.path.join(output_dir, f"RP_{group_name}_profile.csv"), mode='w', newline='') as f:
                writer = csv.writer(f)
                header_row = []
                for lbl_clean in csv_data.keys():
                    header_row.extend([
                        f"R_{lbl_clean}_arcsec", f"Flux_{lbl_clean}_mJy",
                        f"FluxNorm_{lbl_clean}", f"FluxNormErr_{lbl_clean}"
                    ])
                writer.writerow(header_row)
                for i in range(max_pts):
                    row = []
                    for lbl_clean in csv_data.keys():
                        d = csv_data[lbl_clean]
                        if i < len(d['r']):
                            row.extend([
                                f"{d['r'][i]:.6f}", f"{d['i_raw'][i]:.6e}",
                                f"{d['i_norm'][i]:.6f}", f"{d['e_norm'][i]:.6f}"
                            ])
                        else:
                            row.extend(["", "", "", ""])
                    writer.writerow(row)
        except Exception as e:
            tqdm.write(f"[ERROR] Failed to save CSV: {e}")
    pbar.update(1)
    pbar.close()
def main():
    """CLI entry point for the DISCO automated pipeline.

    Parses command-line options, asks the user to confirm a recursive scan
    of the current working directory, optionally loads the CNN geometry
    model, discovers FITS groups, filters them by any positional
    identifiers, and runs :func:`run_pipeline` on each group. Exits with
    status 0 on user cancellation and 1 when nothing matches.
    """
    parser = argparse.ArgumentParser(description="DISCO Automated Pipeline")
    parser.add_argument("identifier", nargs="*", help="Object prefix(es) or FITS file path(s)")
    parser.add_argument("--rout", type=float, default=None, help="Force Rout (arcsec)")
    parser.add_argument("--rmin", type=float, default=0.0, help="Force Rmin (arcsec)")
    parser.add_argument("--incl", type=float, default=None, help="Force inclination (deg)")
    parser.add_argument("--pa", type=float, default=None, help="Force PA (deg)")
    parser.add_argument("--beam", type=float, default=None, help="Force target beam resolution (arcsec)")
    parser.add_argument("--homobeam", type=str, default="on", choices=["on", "off"], help="Toggle beam homogenization")
    parser.add_argument("--csv", type=str, default="off", choices=["on", "off"], help="Export CSV data")
    parser.add_argument("--debug", type=str, default="off", choices=["on", "off"], help="Save debug deprojected image")
    args = parser.parse_args()

    # Interactive confirmation before recursively scanning the CWD.
    print("\n" + "-"*60)
    print("WARNING: DISCO will now scan the current directory and all")
    print("subdirectories to search for and process FITS files.")
    print(f"Current directory: {os.getcwd()}")
    print("-"*60)
    user_response = input("\nAre you sure you want to continue? [y/N]: ").strip().lower()
    if user_response != 'y' and user_response != 'yes':
        print("Operation cancelled by user. Exiting...")
        sys.exit(0)

    # Load the bundled CNN geometry model if present; failures fall back
    # to the analytical Nelder-Mead path inside run_pipeline.
    cnn_model = None
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(BASE_DIR, "models", "disco_model_stable.pth")
    if os.path.exists(model_path):
        try:
            # weights_only=True: refuse pickled code in the checkpoint.
            ckpt = torch.load(model_path, map_location='cpu', weights_only=True)
            cnn_model = DiscoNet(n_out=5)
            state = ckpt["model_state"] if isinstance(ckpt, dict) else ckpt
            cnn_model.load_state_dict(state)
            cnn_model.eval()
            print("[INFO] CNN model loaded.")
        except Exception as e:
            print(f"[WARN] CNN model load failed: {e}. Falling back to analytical geometry.")
    else:
        print("[WARN] Model file not found. Falling back to analytical geometry.")

    base_dir = os.getcwd()
    all_groups = discover_groups(base_dir)
    if not all_groups:
        print("[ERROR] No FITS files found in the current directory tree.")
        sys.exit(1)

    if args.identifier:
        # Keep groups whose output path matches an identifier; otherwise
        # keep a trimmed copy containing only the matching files.
        groups = []
        clean_ids = [ident.strip(',') for ident in args.identifier]
        for g in all_groups:
            path_parts = g['output_dir'].replace('\\', '/').split('/')
            if any(ident in path_parts or ident in g['output_dir'] for ident in clean_ids):
                groups.append(g)
            else:
                matched_files = [
                    f for f in g['files']
                    if any(ident in os.path.basename(f) for ident in clean_ids)
                ]
                if matched_files:
                    groups.append({
                        "name": g['name'],
                        "files": matched_files,
                        "output_dir": g['output_dir']
                    })
        if not groups:
            print(f"[ERROR] No FITS files match the provided identifiers: {clean_ids}")
            sys.exit(1)
    else:
        groups = all_groups

    print(f"[INFO] Found {len(groups)} group(s) to process.\n")
    for group in tqdm(groups, desc="Total Groups", unit="group", dynamic_ncols=True):
        try:
            run_pipeline(group["files"], group["name"], group["output_dir"], args, cnn_model)
        except Exception as e:
            # One failed group must not abort the remaining groups.
            tqdm.write(f"[ERROR] Processing failed for group '{group['name']}': {e}")
    print("\n[INFO] Pipeline execution completed.")


if __name__ == "__main__":
    main()