import os
import shutil
import io
import base64
import webbrowser
from typing import Optional
import threading
import time
import uvicorn
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from matplotlib.ticker import FixedLocator
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import Response, FileResponse
from pydantic import BaseModel
import socket
from astropy.io import fits
from astropy.wcs import WCS
import astropy.units as u
from scipy.ndimage import map_coordinates
from scipy.optimize import minimize, curve_fit
from astropy.visualization import ImageNormalize, AsinhStretch, LogStretch, LinearStretch, SqrtStretch
from disco.core.optimization import geometric_loss
try:
from astroquery.simbad import Simbad
ASTROQUERY_AVAILABLE = True
except ImportError:
ASTROQUERY_AVAILABLE = False
# Directory containing this module; static frontend files are served from
# the "static" folder next to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
# Uploaded FITS files are written under the current working directory
# (not under the package directory).
UPLOAD_DIR = os.path.join(os.getcwd(), ".disco_uploads")
app = FastAPI()
# CORS: only origins on localhost / 127.0.0.1 (http or https, any port)
# match the regex below; all methods and headers are allowed for them.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https?://(localhost|127\.0\.0\.1)(:\d+)?",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class GlobalState:
    """In-memory session state read and written by the endpoints of this module.

    Attributes:
        data: Image array of the currently loaded FITS file, or None.
        header: FITS header of the loaded file, or None.
        filename: Path of the loaded file on disk, or None.
        results: Arrays produced by the last pipeline run, keyed by product name.
        extents: Plot extents matching the entries in ``results``.
        profile_data: Radial-profile lists from the last pipeline run, or None.
    """

    def __init__(self):
        # The dict containers are created per instance: as class attributes
        # they would be one shared object mutated by every instance.
        self.data = None
        self.header = None
        self.filename = None
        self.results = {}
        self.extents = {}
        self.profile_data = None


# Single module-level instance used by all request handlers.
state = GlobalState()
def wipe_session_logic():
    """Reset the in-memory session and empty the upload directory.

    Every field of ``state`` goes back to its initial value. If
    ``UPLOAD_DIR`` exists, each entry in it (file, symlink or directory) is
    removed on a best-effort basis; if it does not exist, it is created.
    """
    for field in ("data", "header", "filename", "profile_data"):
        setattr(state, field, None)
    state.results = {}
    state.extents = {}

    if not os.path.exists(UPLOAD_DIR):
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        return

    for entry in os.listdir(UPLOAD_DIR):
        target = os.path.join(UPLOAD_DIR, entry)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception:
            # Best effort: an entry that cannot be removed is skipped.
            pass


# Start every process with a clean session (runs at import time).
wipe_session_logic()
@app.on_event("shutdown")
def cleanup_on_shutdown():
    """Shutdown hook: announce the cleanup and wipe the session."""
    notice = "\n[INFO] Cleaning directory and temporary files..."
    print(notice)
    wipe_session_logic()
@app.post("/reset_session")
def reset_session_endpoint():
    """Wipe the session on request and report that it was cleared."""
    wipe_session_logic()
    acknowledgement = {"status": "Session cleared"}
    return acknowledgement
# Request body accepted by the /render_plot endpoint.
class PlotParams(BaseModel):
    """Options controlling how ``render_plot`` draws a figure."""

    # Which product to draw: 'profile', or a key of the stored results
    # ('data', 'deproj', 'polar', 'model', 'residuals').
    type: str
    # Matplotlib colormap name passed to imshow.
    cmap: str = 'magma'
    # 'log', 'linear' or 'sqrt'; any other value selects the asinh stretch.
    stretch: str = 'asinh'
    # NOTE(review): not read by render_plot in this file.
    vmax_percentile: Optional[float] = None
    # Explicit display limits; when None, render_plot derives them from the data.
    vmin: Optional[float] = None
    vmax: Optional[float] = None
    # Contour overlay switch and the `levels` value passed to ax.contour.
    contours: bool = False
    contour_levels: int = 5
    # Optional decorations.
    show_beam: bool = False
    show_grid: bool = False
    show_axes: bool = True
    show_colorbar: bool = True
    # Custom title; when empty, render_plot falls back to a per-type default.
    title: Optional[str] = ""
    # Resolution of the created figure.
    dpi: int = 150
# Request body accepted by the /run_pipeline endpoint.
class PipelineParams(BaseModel):
    """Geometry and fit range consumed by ``run_pipeline``."""

    # Centre of the source in pixel coordinates; run_pipeline uses
    # `ny - cy` as the array row.
    cx: float
    cy: float
    # Position angle and inclination; converted with np.radians, so degrees.
    pa: float
    incl: float
    # Outer radius in arcsec (divided by the arcsec/pixel scale in run_pipeline).
    rout: float
    # Radial range, in arcsec, used for the Gaussian fit of the profile.
    # The fit is skipped unless fit_rmax exceeds fit_rmin by more than 0.05.
    fit_rmin: float = 0.0
    fit_rmax: float = 0.0
# Request body accepted by the /optimize_geometry endpoint.
class OptimizeParams(BaseModel):
    """Starting geometry and radial range consumed by ``optimize_geometry``."""

    # Centre of the source in pixel coordinates; optimize_geometry uses
    # `ny - cy` as the array row.
    cx: float
    cy: float
    # Initial position angle and inclination handed to the optimiser as the
    # first guess.
    pa: float
    incl: float
    # Outer radius in arcsec; together with fit_rmax it sets the crop size.
    rout: float
    # Radial range in arcsec, converted to pixels and passed to geometric_loss.
    fit_rmin: float = 0.0
    fit_rmax: float = 0.0
# Request body accepted by the /load_local endpoint.
class LoadLocalParams(BaseModel):
    """Name of a previously uploaded file to load again."""

    # Only the base name is used; load_local looks it up inside UPLOAD_DIR.
    filename: str
def array_to_base64(data_array, cmap='magma', stretch_val=0.03):
    """Render a 2-D array as a borderless PNG and return it base64-encoded.

    Args:
        data_array: Image array to draw.
        cmap: Matplotlib colormap name.
        stretch_val: Parameter passed to ``AsinhStretch``.

    Returns:
        The PNG bytes encoded as a base64 ``str`` (no ``data:`` prefix).
    """
    # Scale from 0 to the finite maximum; fall back to 1.0 when the maximum
    # is NaN or not positive so the normalisation stays valid.
    mx = np.nanmax(data_array)
    if np.isnan(mx) or mx <= 0:
        mx = 1.0
    norm = ImageNormalize(vmin=0.0, vmax=mx, stretch=AsinhStretch(stretch_val))

    fig = plt.figure(figsize=(6, 6), dpi=150)
    try:
        # Axes fill the whole figure so the PNG contains only the image.
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(data_array, origin='lower', cmap=cmap, norm=norm,
                  interpolation='nearest', aspect='auto')
        buf = io.BytesIO()
        # Save through the figure object: plt.savefig() writes whatever
        # pyplot considers the *current* figure, which need not be `fig`
        # when several requests are rendered concurrently.
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Release the pyplot-managed figure even if drawing or saving failed.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')
def gaussian(x, a, x0, sigma, c):
    """Evaluate a Gaussian of amplitude ``a``, centre ``x0`` and width
    ``sigma`` on top of a constant offset ``c`` at ``x``."""
    exponent = -(x - x0) ** 2 / (2 * sigma ** 2)
    return a * np.exp(exponent) + c
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded FITS file in ``UPLOAD_DIR`` and load it into the session.

    The primary HDU is squeezed, NaNs are replaced via ``np.nan_to_num``, and
    the array is multiplied by 1000 when its maximum is below 0.1.

    Returns:
        dict with the stored file name, a status string, the array shape and
        ``abs(CDELT2) * 3600`` as ``pixel_scale``.

    Raises:
        HTTPException: 400 if the upload carries no usable file name,
            500 if writing or reading the file fails.
    """
    # The file name is client-controlled: keep only its last component so it
    # cannot point outside UPLOAD_DIR (load_local applies the same rule).
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name)
    if not safe_name:
        raise HTTPException(status_code=400, detail="Invalid filename")

    # NOTE(review): this coroutine performs blocking file I/O.
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        file_location = os.path.join(UPLOAD_DIR, safe_name)
        with open(file_location, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        with fits.open(file_location) as hdul:
            data = np.nan_to_num(np.squeeze(hdul[0].data))
            header = hdul[0].header
        if np.max(data) < 0.1:
            data *= 1000

        # Commit the new file and drop everything derived from the previous
        # one, so stale extents or profiles are never drawn for this image.
        state.data = data
        state.header = header
        state.filename = file_location
        state.results = {}
        state.extents = {}
        state.profile_data = None

        cdelt = header.get('CDELT2', 0.03)
        pixel_scale = abs(cdelt) * 3600
        return {"filename": safe_name, "status": "loaded", "shape": data.shape, "pixel_scale": pixel_scale}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/load_local")
def load_local(params: LoadLocalParams):
    """Load a file that already exists in ``UPLOAD_DIR`` into the session.

    Only the base name of ``params.filename`` is used, so the lookup cannot
    leave ``UPLOAD_DIR``.

    Returns:
        dict with a status string, the file name, the array shape and
        ``abs(CDELT2) * 3600`` as ``pixel_scale``.

    Raises:
        HTTPException: 404 if no such regular file exists, 500 if it cannot
            be read.
    """
    clean_name = os.path.basename(params.filename)
    file_path = os.path.join(UPLOAD_DIR, clean_name)
    # isfile (not exists): an empty name resolves to the directory itself,
    # which should be a 404 rather than a read failure.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    try:
        with fits.open(file_path) as hdul:
            data = np.nan_to_num(np.squeeze(hdul[0].data))
            header = hdul[0].header
        # Same rescaling rule as the /upload endpoint.
        if np.max(data) < 0.1:
            data *= 1000

        # Commit the new file and drop everything derived from the previous
        # one, so stale extents or profiles are never drawn for this image.
        state.data = data
        state.header = header
        state.filename = file_path
        state.results = {}
        state.extents = {}
        state.profile_data = None

        cdelt = header.get('CDELT2', 0.03)
        pixel_scale = abs(cdelt) * 3600
        return {"status": "loaded", "filename": clean_name, "shape": data.shape, "pixel_scale": pixel_scale}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/preview")
def get_preview():
    """Return the loaded image as an inline PNG data URI (404 when nothing is loaded)."""
    image = state.data
    if image is None:
        raise HTTPException(status_code=404, detail="No data")
    encoded = array_to_base64(image, cmap='inferno', stretch_val=0.02)
    return {"image": f"data:image/png;base64,{encoded}"}
@app.get("/get_header")
def get_header():
    """List the header cards of the loaded file as key/value/comment dicts.

    COMMENT and HISTORY entries are left out; an empty list is returned when
    no header is loaded.
    """
    header = state.header
    if header is None:
        return {"header": []}
    skipped = ('COMMENT', 'HISTORY')
    entries = [
        {"key": key, "value": str(value), "comment": header.comments[key]}
        for key, value in header.items()
        if key not in skipped
    ]
    return {"header": entries}
@app.post("/optimize_geometry")
def optimize_geometry(params: OptimizeParams):
    """Refine inclination and position angle by minimising ``geometric_loss``.

    A square crop around the requested centre is extracted from a zero-padded
    copy of the image. A coarse inclination/PA grid is scanned to pick a
    starting point, then a bounded Nelder-Mead search refines
    ``[incl, pa, dx, dy]``.

    Returns:
        dict with ``optimized_incl`` and ``optimized_pa`` (PA folded into
        the range [0, 180)).

    Raises:
        HTTPException: 400 when no image is loaded.
    """
    if state.data is None:
        raise HTTPException(status_code=400, detail="No Data")
    data = state.data
    ny, _ = data.shape
    # The requested y is counted from the opposite edge of the array.
    eff_cy = ny - params.cy
    eff_cx = params.cx

    # Size of the crop, in pixels, from the larger of the two radii.
    pixel_scale = abs(state.header.get('CDELT2', 0.03)) * 3600
    search_rad = max(params.rout, params.fit_rmax)
    search_rad_pix = int(search_rad / pixel_scale)
    crop_rad = int(search_rad_pix * 1.2) + 10

    # Pad by at least the crop radius. With a fixed 1000-pixel pad, a larger
    # crop would produce negative slice starts, which numpy interprets as
    # "from the end" and silently returns the wrong window.
    pad = max(1000, crop_rad)
    d_pad = np.pad(data, pad, mode='constant', constant_values=0)
    real_cy_int = int(eff_cy) + pad
    real_cx_int = int(eff_cx) + pad
    # Sub-pixel remainder of the centre, re-applied inside the crop.
    offset_y = (eff_cy + pad) - real_cy_int
    offset_x = (eff_cx + pad) - real_cx_int
    dc = d_pad[real_cy_int - crop_rad: real_cy_int + crop_rad,
               real_cx_int - crop_rad: real_cx_int + crop_rad]
    local_c_y = crop_rad + offset_y
    local_c_x = crop_rad + offset_x
    rmin_pix = params.fit_rmin / pixel_scale
    rmax_pix = params.fit_rmax / pixel_scale

    # Coarse scan (cheap settings: dim=100, order=1); the user's guess is
    # kept unless a grid point scores strictly lower.
    best_guess = [params.incl, params.pa, 0.0, 0.0]
    min_loss = geometric_loss(best_guess, dc, local_c_x, local_c_y, crop_rad, rmin_pix, rmax_pix, dim=100, order=1)
    for trial_incl in [10, 30, 50, 70]:
        for trial_pa in range(0, 180, 30):
            trial_loss = geometric_loss([trial_incl, trial_pa, 0.0, 0.0], dc, local_c_x, local_c_y, crop_rad, rmin_pix, rmax_pix, dim=100, order=1)
            if trial_loss < min_loss:
                min_loss = trial_loss
                best_guess = [trial_incl, trial_pa, 0.0, 0.0]

    # Fine search with the trailing positional arguments 400 and 3.
    res = minimize(
        geometric_loss,
        best_guess,
        args=(dc, local_c_x, local_c_y, crop_rad, rmin_pix, rmax_pix, 400, 3),
        method='Nelder-Mead',
        bounds=[(0, 85), (0, 180), (-10, 10), (-10, 10)],
        tol=0.01
    )
    best_incl, best_pa = res.x[:2]
    # `%` with a positive divisor is never negative, so no further folding
    # is needed.
    best_pa = best_pa % 180
    return {"optimized_incl": float(best_incl), "optimized_pa": float(best_pa)}
@app.post("/run_pipeline")
def run_pipeline(params: PipelineParams):
    """Deproject the loaded image and derive polar map, model, residuals and profile.

    Steps, in order: crop a 2000x2000 window around the requested centre,
    resample it onto a 1000x1000 deprojected grid, unwrap that grid to polar
    coordinates, average over azimuth to get a radial profile, rebuild an
    axisymmetric model from the profile, subtract it to obtain residuals,
    convert the profile with the beam/frequency header values when present,
    and optionally fit a Gaussian to the profile inside
    [fit_rmin, fit_rmax].

    Side effects: overwrites ``state.results``, ``state.extents`` and
    ``state.profile_data``.

    Returns:
        dict with base64 PNG previews (``images``), the radial ``profile``,
        ``geometry`` information and the Gaussian ``fit`` (or None).

    Raises:
        HTTPException: 400 when no image is loaded.
    """
    if state.data is None:
        raise HTTPException(status_code=400, detail="No FITS data loaded.")
    data = state.data
    ny, nx = data.shape
    # The requested y is counted from the opposite edge of the array
    # (presumably the client uses a top-left origin — confirm with the frontend).
    eff_cy = ny - params.cy
    eff_cx = params.cx
    pa_rad = np.radians(params.pa)
    incl_rad = np.radians(params.incl)
    # Used below as arcsec per pixel (assumes CDELT2 is in degrees — TODO confirm).
    pixel_scale = abs(state.header.get('CDELT2', 0.03)) * 3600
    # --- Crop a fixed 2000x2000 window around the centre -------------------
    crop_size = 2000
    crop_rad = crop_size // 2
    pad = crop_rad
    # Zero padding by the crop radius lets the window extend past the image edge.
    d_pad = np.pad(data, pad, mode='constant', constant_values=0)
    y_start_int = int(eff_cy) + pad - crop_rad
    x_start_int = int(eff_cx) + pad - crop_rad
    # Centre position inside the crop, keeping the sub-pixel remainder.
    local_cy = (eff_cy + pad) - y_start_int
    local_cx = (eff_cx + pad) - x_start_int
    dc = d_pad[y_start_int: y_start_int + crop_size, x_start_int: x_start_int + crop_size]
    # If the slice came out short, paste it into a zero canvas of full size.
    if dc.shape != (crop_size, crop_size):
        temp = np.zeros((crop_size, crop_size))
        h, w = dc.shape
        temp[0:h, 0:w] = dc
        dc = temp
    # --- Beam description for the response (None when BMAJ is absent or
    # the lookup fails) -------------------------------------------------------
    beam_info = None
    try:
        if 'BMAJ' in state.header:
            beam_info = {
                "major": state.header['BMAJ'] * 3600,
                "minor": state.header['BMIN'] * 3600,
                "pa": state.header.get('BPA', 0.0)
            }
    except Exception:
        pass
    # --- Deprojection onto a 1000x1000 grid centred on index 500 -----------
    dim = 1000
    x = np.arange(dim) - 500
    X, Y = np.meshgrid(x, x)
    # Scale x by cos(inclination), then rotate by the position angle.
    Xc = X * np.cos(incl_rad)
    Xrot = np.cos(pa_rad) * Xc + np.sin(pa_rad) * Y
    Yrot = -np.sin(pa_rad) * Xc + np.cos(pa_rad) * Y
    # map_coordinates takes (row, column) sample positions in the crop.
    coords_deproj = [Yrot + local_cy, -Xrot + local_cx]
    deproj = map_coordinates(dc, coords_deproj, order=3, cval=0.0)
    deproj = np.fliplr(deproj)
    # --- Polar unwrapping: radius 0..hypot(500, 500), azimuth -180..180 ----
    max_radius_pix = np.hypot(500, 500)
    n_steps = int(max_radius_pix)
    r_full = np.linspace(0, max_radius_pix, n_steps)
    th = np.linspace(-180, 180, 361)
    R, TH = np.meshgrid(r_full, th)
    Xd = R * np.cos(np.radians(TH))
    Yd = R * np.sin(np.radians(TH))
    coords_polar = [Yd + 500, Xd + 500]
    # Rows are azimuth samples, columns are radius samples.
    polar_full = map_coordinates(deproj, coords_polar, order=1)
    polar_full = np.flipud(polar_full)
    # Azimuthal average -> one value per radius.
    prof_full = np.nanmean(polar_full, axis=0)
    # --- Axisymmetric model and residuals -----------------------------------
    d_map = np.sqrt(X ** 2 + Y ** 2)
    mod = np.interp(d_map.flatten(), r_full, prof_full).reshape(dim, dim)
    resi = deproj - mod
    # --- Truncate the polar products at the requested outer radius ---------
    rout_pix = params.rout / pixel_scale
    limit_idx = np.searchsorted(r_full, rout_pix)
    limit_idx = min(limit_idx, n_steps)
    polar_display = polar_full[:, :limit_idx]
    prof_display = prof_full[:limit_idx]
    r_display = r_full[:limit_idx]
    # --- Profile conversion using beam size and rest frequency -------------
    # Matches the Rayleigh-Jeans relation T = c^2 S / (2 k nu^2 Omega) in CGS
    # units; falls back to the unconverted profile when the beam is missing
    # or anything fails.
    try:
        bmaj = state.header.get('BMAJ', 0) * 3600
        bmin = state.header.get('BMIN', 0) * 3600
        restfrq = state.header.get('RESTFRQ', 230e9)
        if bmaj > 0 and bmin > 0:
            # Gaussian beam area in arcsec^2, divided by (arcsec per radian)^2.
            beam_sr = (np.pi * bmaj * bmin / (4 * np.log(2))) / 206265 ** 2
            kB = 1.38e-16
            c = 3e10
            # prof / 1000 undoes a x1000 scaling (assumes the profile is in
            # milli-units after loading — TODO confirm); 1e-23 is the Jy
            # factor in CGS.
            tb_prof = (c ** 2 * 1e-23 * prof_display / 1000) / (2 * kB * restfrq ** 2 * beam_sr)
        else:
            tb_prof = prof_display
    except Exception:
        tb_prof = prof_display
    r_arcsec = r_display * pixel_scale
    # --- Central 1000x1000 view of the crop, matching the deprojected grid -
    start = crop_rad - 500
    end = crop_rad + 500
    dc_view = dc[start:end, start:end]
    fov_arcsec = 1000 * pixel_scale
    limit_arcsec = fov_arcsec / 2
    # imshow extent [left, right, bottom, top]; x runs from +limit to -limit.
    ext_cartesian = [limit_arcsec, -limit_arcsec, -limit_arcsec, limit_arcsec]
    ext_polar = [0, params.rout, -180, 180]
    # --- Persist products for /render_plot and /download_fits --------------
    state.results = {'data': dc_view, 'deproj': deproj, 'polar': polar_display, 'model': mod, 'residuals': resi}
    state.extents = {'data': ext_cartesian, 'deproj': ext_cartesian, 'model': ext_cartesian, 'residuals': ext_cartesian, 'polar': ext_polar}
    prof_jy = prof_display / 1000.0
    state.profile_data = {'radius': r_arcsec.tolist(), 'tb': tb_prof.tolist(), 'raw': prof_jy.tolist()}
    # --- Optional Gaussian fit of the profile inside [fit_rmin, fit_rmax] --
    fit_stats = None
    if params.fit_rmax > params.fit_rmin and (params.fit_rmax - params.fit_rmin) > 0.05:
        try:
            mask = (r_arcsec >= params.fit_rmin) & (r_arcsec <= params.fit_rmax)
            x_region = r_arcsec[mask]
            y_region = tb_prof[mask]
            # Require more than 5 samples and a positive peak before fitting.
            if len(y_region) > 5:
                idx_max = np.argmax(y_region)
                amp_guess = y_region[idx_max]
                mean_guess = x_region[idx_max]
                if amp_guess > 0:
                    sigma_guess = (params.fit_rmax - params.fit_rmin) / 4
                    p0 = [amp_guess, mean_guess, sigma_guess, 0.0]
                    popt, _ = curve_fit(gaussian, x_region, y_region, p0=p0, maxfev=2000)
                    # FWHM = 2.355 * sigma for a Gaussian.
                    fwhm = 2.355 * abs(popt[2])
                    fit_stats = {
                        "peak_radius": float(popt[1]),
                        "fwhm": float(fwhm),
                        "peak_intensity": float(popt[0])
                    }
        except Exception:
            # A failed fit is reported as "no fit" rather than as an error.
            fit_stats = None
    fov_cartesian = 1000 * pixel_scale
    fov_polar = params.rout
    return {
        "images": {
            "data": f"data:image/png;base64,{array_to_base64(dc_view, cmap='inferno')}",
            "deproj": f"data:image/png;base64,{array_to_base64(deproj, cmap='inferno')}",
            "polar": f"data:image/png;base64,{array_to_base64(polar_display, cmap='inferno', stretch_val=0.1)}",
            "model": f"data:image/png;base64,{array_to_base64(mod, cmap='inferno')}",
            "residuals": f"data:image/png;base64,{array_to_base64(resi, cmap='magma', stretch_val=0.9)}"
        },
        "profile": {"radius": r_arcsec.tolist(), "intensity": tb_prof.tolist(), "raw_intensity": prof_jy.tolist()},
        "geometry": {"fov_cartesian": fov_cartesian, "fov_polar": fov_polar, "beam": beam_info, "pixel_scale": pixel_scale},
        "fit": fit_stats
    }
@app.post("/render_plot")
def render_plot(params: PlotParams):
    """Render the radial profile or one stored image product as a PNG data URI.

    ``params.type == 'profile'`` draws the stored radial profile on a log
    axis; any other type draws the matching entry of ``state.results`` (or
    the raw image for ``'data'`` when no results exist).

    Returns:
        dict with ``image`` (a ``data:image/png;base64,...`` string) and
        ``stats`` (data min/max, the limits actually used, the colormap).

    Raises:
        HTTPException: 400 when the requested product is not available.
    """
    plt.style.use('default')
    is_wide = params.type in ('polar', 'profile')
    fig = plt.figure(figsize=(12, 5) if is_wide else (10, 10), dpi=params.dpi)
    try:
        if params.type == 'profile':
            return _render_profile_figure(fig, params)
        return _render_image_figure(fig, params)
    finally:
        # Close on every path, including errors raised while drawing, so
        # pyplot does not accumulate open figures.
        plt.close(fig)


def _figure_to_data_uri(fig, **savefig_kwargs):
    """Save ``fig`` as a tightly cropped PNG and return it as a data URI."""
    buf = io.BytesIO()
    # Save via the figure itself rather than pyplot's "current figure".
    fig.savefig(buf, format='png', bbox_inches='tight', **savefig_kwargs)
    encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"


def _render_profile_figure(fig, params):
    """Draw the stored radial profile on ``fig`` and build the response dict."""
    if state.profile_data is None:
        raise HTTPException(status_code=400, detail="Profile data not available.")

    ax = fig.add_subplot(111)
    ax.set_facecolor('white')
    x = np.array(state.profile_data['radius'])
    y = np.array(state.profile_data['tb'])
    # A log axis cannot show non-positive or non-finite values; clamp them.
    safe_y = np.where((y > 0) & np.isfinite(y), y, 1e-10)
    ax.plot(x, safe_y, 'k', lw=1.5)
    ax.set_yscale('log')

    # An empty profile has no min/max; use the fixed fallbacks 0.1 and 100
    # both for the axis limits and for the reported statistics.
    has_samples = len(safe_y) > 0
    data_min = float(np.min(safe_y)) if has_samples else 0.1
    data_max = float(np.max(safe_y)) if has_samples else 100.0
    vmin = data_min if params.vmin is None else params.vmin
    vmax = data_max if params.vmax is None else params.vmax

    ax.set_xlim(0, x[-1] if len(x) > 0 else 1)
    ax.set_ylim(vmin, vmax)
    ax.set_xlabel("Radius [arcsec]", fontsize=12)
    ax.set_ylabel("Tb [K]", fontsize=12)
    ax.tick_params(direction='in', labelsize=10)
    if params.show_grid:
        ax.grid(True, which='both', color='gray', alpha=0.3, linestyle='--')
    ax.set_title(params.title if params.title else "Radial Profile", fontweight='bold', fontsize=14)

    return {
        "image": _figure_to_data_uri(fig, pad_inches=0.1, facecolor='white'),
        "stats": {
            "min": data_min,
            "max": data_max,
            "vmin_used": float(vmin),
            "vmax_used": float(vmax),
            "cmap_used": params.cmap
        }
    }


def _render_image_figure(fig, params):
    """Draw one image product on ``fig`` and build the response dict."""
    image_data = None
    if state.results and params.type in state.results:
        image_data = state.results[params.type]
    elif params.type == 'data' and state.data is not None:
        image_data = state.data
    if image_data is None:
        raise HTTPException(status_code=400, detail=f"Data for '{params.type}' not available.")

    if params.show_axes:
        ax = fig.add_subplot(111)
        ax.set_facecolor('white')
    else:
        # Borderless mode: the image fills the whole figure.
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')

    # Resolve display limits. Residuals default to a range symmetric about
    # zero; other products default to [0, max]. Each missing limit is filled
    # independently, so a request with only one of vmin/vmax still gets a
    # numeric range.
    vmin, vmax = params.vmin, params.vmax
    if params.type == 'residuals':
        if vmin is None or vmax is None:
            limit = np.percentile(np.abs(image_data), 100)
            if vmin is None:
                vmin = -limit
            if vmax is None:
                vmax = limit
    else:
        if vmin is None:
            vmin = 0.0
        if vmax is None:
            vmax = np.percentile(image_data, 100)
    # Keep the range non-degenerate for the normalisation.
    if vmax <= vmin:
        vmax = vmin + 1e-10

    stretch_classes = {'log': LogStretch, 'linear': LinearStretch, 'sqrt': SqrtStretch}
    stretch_cls = stretch_classes.get(params.stretch)
    # Any unrecognised stretch name falls back to asinh.
    stretch = stretch_cls() if stretch_cls is not None else AsinhStretch(0.02)
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    aspect = 'auto' if params.type == 'polar' else 'equal'
    extent = state.extents.get(params.type, None)
    im = ax.imshow(image_data, origin='lower', cmap=params.cmap, norm=norm, aspect=aspect, extent=extent)

    if params.show_axes:
        if params.title:
            ax.set_title(params.title, fontweight='bold', fontsize=14)
        else:
            titles = {'data': "Input Data", 'deproj': "Deprojected View", 'polar': "Polar Map", 'model': "Azimuthal Model", 'residuals': "Residual Map"}
            ax.set_title(titles.get(params.type, params.type.capitalize()), fontweight='bold', fontsize=14)
        ax.tick_params(direction='in', labelsize=10, color='black')
        if params.type == 'polar':
            ax.set_xlabel("Radius [arcsec]", fontsize=12)
            ax.set_ylabel("Azimuth [deg]", fontsize=12)
        else:
            ax.set_xlabel("RA Offset [arcsec]", fontsize=12)
            ax.set_ylabel("Dec Offset [arcsec]", fontsize=12)
        if params.show_grid:
            ax.grid(True, color='white', alpha=0.3, linestyle='--')

    # NOTE(review): the colorbar is drawn independently of show_axes here —
    # confirm this matches the intended nesting.
    if params.show_colorbar:
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        if vmax <= 10 and params.stretch == 'asinh':
            cbar.locator = FixedLocator([0.0, 0.2, 0.5, 1.0, 2.0, 4.0])
            cbar.update_ticks()
        cbar.set_label('Intensity', fontsize=10)
        cbar.ax.tick_params(labelsize=9)

    if params.contours:
        try:
            ax.contour(image_data, levels=params.contour_levels, colors='white', alpha=0.5, linewidths=0.8, extent=extent)
        except Exception:
            # Contours are an optional overlay; a failure must not lose the plot.
            pass

    if params.show_beam and params.type != 'polar' and params.show_axes:
        try:
            header = state.header
            if header is not None and 'BMAJ' in header:
                bmaj = header['BMAJ'] * 3600
                bmin = header['BMIN'] * 3600
                bpa = header.get('BPA', 0.0)
                if extent:
                    # Place the beam 10% in from the extent's first x and y values.
                    width_phys = abs(extent[1] - extent[0])
                    height_phys = abs(extent[3] - extent[2])
                    bx = extent[0] + width_phys * 0.1
                    by = extent[2] + height_phys * 0.1
                    beam_patch = Ellipse((bx, by), width=bmin, height=bmaj, angle=bpa, facecolor='white', edgecolor='black', zorder=20)
                    ax.add_patch(beam_patch)
        except Exception:
            # The beam marker is an optional overlay as well.
            pass

    image_uri = _figure_to_data_uri(
        fig,
        pad_inches=0.1 if params.show_axes else 0,
        transparent=not params.show_axes,
        facecolor='white',
    )
    return {
        "image": image_uri,
        "stats": {
            "min": float(np.min(image_data)),
            "max": float(np.max(image_data)),
            "vmin_used": float(vmin),
            "vmax_used": float(vmax),
            "cmap_used": params.cmap
        }
    }
@app.get("/download_fits")
def download_fits(type: str):
    """Return one stored product (or the raw image) as a FITS file download.

    Raises:
        HTTPException: 400 when ``type`` names neither a stored result nor
            the loaded image.
    """
    in_results = type in state.results
    raw_available = type == 'data' and state.data is not None
    if not in_results and not raw_available:
        raise HTTPException(status_code=400, detail="Data not found")
    array = state.results[type] if in_results else state.data

    # Serialise into memory, reusing the header of the loaded file.
    stream = io.BytesIO()
    fits.PrimaryHDU(data=array, header=state.header).writeto(stream)
    disposition = {"Content-Disposition": f"attachment; filename=result_{type}.fits"}
    return Response(content=stream.getvalue(), media_type="application/octet-stream", headers=disposition)
@app.get("/query_simbad")
def query_simbad():
    """Query SIMBAD within 2 arcmin of the sky position at the image centre.

    Returns:
        ``{"found": False, "data": []}`` when the query returns None,
        otherwise ``{"found": True, "data": [...]}`` with one dict per row.

    Raises:
        HTTPException: 501 if astroquery is missing, 400 if no header is
            loaded, 500 if the WCS conversion or the query fails.
    """
    if not ASTROQUERY_AVAILABLE:
        raise HTTPException(status_code=501, detail="Librería 'astroquery' no instalada.")
    if state.header is None:
        raise HTTPException(status_code=400, detail="No header loaded.")

    def _to_plain(value):
        # Turn a table cell into something JSON can carry: decode bytes,
        # blank out masked cells, unwrap numpy scalars.
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        if np.ma.is_masked(value):
            value = ""
        if isinstance(value, (np.integer, int)):
            return int(value)
        if isinstance(value, (np.floating, float)):
            return float(value)
        return value

    try:
        wcs = WCS(state.header)
        # Keep only the celestial axes when the header describes more than two.
        if wcs.naxis > 2:
            wcs = wcs.celestial
        half_x = state.header.get('NAXIS1', 0) / 2
        half_y = state.header.get('NAXIS2', 0) / 2
        center_sky = wcs.pixel_to_world(half_x, half_y)

        custom_simbad = Simbad()
        custom_simbad.add_votable_fields('otype', 'flux(V)', 'distance')
        result_table = custom_simbad.query_region(center_sky, radius=2 * u.arcmin)
        if result_table is None:
            return {"found": False, "data": []}

        columns = result_table.colnames
        json_data = [
            {col: _to_plain(row[col]) for col in columns}
            for row in result_table
        ]
        return {"found": True, "data": json_data}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Serve the files under STATIC_DIR/assets at the /assets URL prefix.
app.mount("/assets", StaticFiles(directory=os.path.join(STATIC_DIR, "assets")), name="assets")
@app.get("/{full_path:path}")
async def serve_react_app(full_path: str):
    """Catch-all route: serve a file from ``STATIC_DIR`` or fall back to index.html.

    The requested path is resolved and must stay inside ``STATIC_DIR``;
    anything else (missing file, directory, ``..`` escape, absolute path)
    gets ``index.html``. Every response is sent with caching disabled.
    """
    static_root = os.path.realpath(STATIC_DIR)
    # The URL path is client-controlled. Resolve it and require the result
    # to live under the static root, otherwise "../" segments or an absolute
    # path would expose arbitrary files on the host.
    candidate = os.path.realpath(os.path.join(static_root, full_path))
    inside_root = os.path.normcase(candidate).startswith(
        os.path.normcase(static_root) + os.sep
    )
    if inside_root and os.path.isfile(candidate):
        response = FileResponse(candidate)
    else:
        response = FileResponse(os.path.join(STATIC_DIR, "index.html"))
    response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate, max-age=0"
    response.headers["Pragma"] = "no-cache"
    response.headers["Expires"] = "0"
    return response
def get_free_port():
    """Return the first port, counting up from 8000, that refuses a TCP
    connection on localhost."""
    candidate = 8000
    while True:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # connect_ex returns 0 only when something accepted the connection.
            occupied = probe.connect_ex(('localhost', candidate)) == 0
        finally:
            probe.close()
        if not occupied:
            return candidate
        candidate += 1
# Entry point that launches the GUI server.
def start_server():
    """Pick a port, print the banner, open the browser and run uvicorn.

    The browser is opened from a daemon thread after a 1.5 s delay; the
    call then blocks inside ``uvicorn.run``.
    """
    port = get_free_port()
    url = f"http://localhost:{port}"

    rule = "=" * 50
    banner = [
        "\n" + rule,
        " DISCO GUI IS RUNNING",
        f" URL: {url}",
        " Press Ctrl+C to stop the server safely",
        rule + "\n",
    ]
    print("\n".join(banner))

    def _open_when_ready():
        # Small delay before opening the page in the browser.
        time.sleep(1.5)
        webbrowser.open(url)

    opener = threading.Thread(target=_open_when_ready, daemon=True)
    opener.start()
    # NOTE(review): binds to all interfaces although the banner and the CORS
    # rule refer to localhost only — confirm this exposure is intended.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="warning")


if __name__ == "__main__":
    start_server()