1"""Create an RGB image matrix of the observation using objects found by
2:doc:`skysim.query <query>` and image configuration from an
3`~skysim.settings.ImageSettings` object.
4"""
5
6# License: GPLv3+ (see COPYING); Copyright (C) 2025 Tai Withers
7
8from datetime import datetime, time, timedelta
9from multiprocessing import Pool, cpu_count
10
11import numpy as np
12from astropy import units as u
13from astropy.coordinates import SkyCoord
14from astropy.table import QTable, Row, vstack
15from matplotlib.colors import LinearSegmentedColormap
16from numpy.typing import ArrayLike
17from pydantic import NonNegativeFloat, PositiveInt
18
19from skysim.colours import RGBTuple
20from skysim.settings import ImageSettings
21from skysim.utils import (
22 FloatArray,
23 IntArray,
24 round_columns,
25)
26
27# Constants
28
29
#: Lower bound of the rescaled object brightness scale (the upper bound is 1).
#: The faintest object in a frame's table is given this brightness rather than
#: 0, so that it still stands out from the background it is blended onto.
MINIMUM_BRIGHTNESS: float = 0.2
34
35
36# Methods
37
38
39## Top-Level Populate Method
40
41
[docs]
def create_image_matrix(
    image_settings: ImageSettings,
    planet_tables: list[QTable],
    star_table: QTable,
    verbose_level: int,
) -> FloatArray:
    """Primary function for the populate module. Creates and fills in the image
    matrix.

    Every frame is first filled with the background colour for its local time.
    The objects visible in each frame are then drawn in parallel using a
    process pool.

    Parameters
    ----------
    image_settings : ImageSettings
        Configuration needed.
    planet_tables : list[astropy.table.QTable]
        Result of planet queries, one table per frame.
    star_table : astropy.table.QTable
        Result of SIMBAD queries.
    verbose_level : int
        How much detail to print.

    Returns
    -------
    FloatArray
        Array of RGB image frames with shape
        (frames, image_pixels, image_pixels, 3).
    """
    image_matrix = get_empty_image(image_settings.frames, image_settings.image_pixels)

    # fill all the backgrounds
    for i in range(image_settings.frames):
        background_colour = get_timed_background_colour(
            image_settings.colour_mapping, image_settings.local_datetimes[i]
        )
        image_matrix[i] = fill_frame_background(background_colour, image_matrix[i])

    # prepare tables for each frame
    object_tables = [
        prepare_object_table(image_settings, star_table, planet_tables, i)
        for i in range(image_settings.frames)
    ]

    # only frames that contain at least one object need to be drawn
    tasks = [
        (i, image_matrix[i], object_tables[i], image_settings, verbose_level)
        for i in range(image_settings.frames)
        if len(object_tables[i]) > 0
    ]

    # add in all the objects
    if tasks:
        # At least one worker (Pool(0) raises ValueError), and never more
        # workers than there are frames to draw.
        processes = min(get_worker_count(), len(tasks))
        with Pool(processes) as pool:
            filled_frames = pool.starmap(fill_frame_objects, tasks)

        # workers return (frame number, frame); write each back to its slot
        for index, frame in filled_frames:
            image_matrix[index] = frame

    image_matrix = np.moveaxis(image_matrix, 1, -1)  # put the RGB axis at the end

    image_matrix = np.flip(image_matrix, axis=2)  # put the x-axis the right way round

    return image_matrix


def get_worker_count(available_cpus: int | None = None) -> int:
    """Number of worker processes to use for drawing frames.

    One CPU is left free for the parent process, but the result is never
    less than 1. Without this clamp a single-core machine would request
    ``Pool(0)``, which raises ValueError.

    Parameters
    ----------
    available_cpus : int | None, optional
        CPU count to base the calculation on. Defaults to
        `multiprocessing.cpu_count()`.

    Returns
    -------
    int
        A process count of at least 1.
    """
    if available_cpus is None:
        available_cpus = cpu_count()
    return max(1, available_cpus - 1)
102
103
104## Helper Methods
105
106
[docs]
def get_empty_image(
    frames: PositiveInt,
    image_pixels: PositiveInt,
) -> FloatArray:
    """Allocate a zero-filled stack of channel-first RGB frames.

    Parameters
    ----------
    frames : pydantic.PositiveInt
        How many frames the stack holds.
    image_pixels : pydantic.PositiveInt
        Side length, in pixels, of each square frame.

    Returns
    -------
    FloatArray
        Zeros with shape (frames, 3, image_pixels, image_pixels). The second
        axis holds the three colour channels.
    """
    colour_channels = 3
    stack_shape = (frames, colour_channels, image_pixels, image_pixels)
    return np.zeros(stack_shape)
133
134
[docs]
def get_seconds_from_midnight(local_time: time | datetime) -> NonNegativeFloat:
    """Calculate the number of seconds since midnight for some time value.

    Only the wall-clock fields (hour, minute, second, microsecond) are used.
    Any date part and timezone information are ignored.

    Parameters
    ----------
    local_time : datetime.time | datetime.datetime
        Python time or datetime value, can be naive.

    Returns
    -------
    pydantic.NonNegativeFloat
        Number of seconds, in the range [0, 86400).
    """
    delta_midnight = timedelta(
        hours=local_time.hour,
        minutes=local_time.minute,
        # whole seconds must be included, otherwise results are truncated
        # to the start of the minute
        seconds=local_time.second,
        microseconds=local_time.microsecond,
    )
    return delta_midnight.total_seconds()
154
155
[docs]
def get_timed_background_colour(
    background_colours: LinearSegmentedColormap,
    local_datetime: datetime,
) -> RGBTuple:
    """Look up the backdrop colour for a given local time of day.

    The time is converted to a fraction of a day (midnight maps to 0) and
    passed to the colourmap. The last (alpha) component of the colourmap's
    result is dropped.

    Parameters
    ----------
    background_colours : matplotlib.colors.LinearSegmentedColormap
        Colourmap taking a day fraction in [0, 1] to an RGBA colour.
    local_datetime : datetime.datetime
        Local time of the observation.

    Returns
    -------
    RGBTuple
        RGB colour for `local_datetime`.
    """
    seconds_per_day = 24 * 60 * 60
    fraction_of_day = get_seconds_from_midnight(local_datetime) / seconds_per_day
    rgba = background_colours(fraction_of_day)
    return rgba[:-1]
177
178
[docs]
def get_timed_magnitude(
    magnitude_mapping: FloatArray, local_datetime: datetime
) -> float:
    """Look up the faintest magnitude that is visible at a given local time.

    Parameters
    ----------
    magnitude_mapping : FloatArray
        Per-second lookup table: entry ``s`` is the limiting magnitude ``s``
        seconds after midnight.
    local_datetime : datetime.datetime
        Local time of the observation. Fractions of a second are truncated
        when indexing.

    Returns
    -------
    float
        Limiting magnitude for `local_datetime`. Objects with a magnitude at
        or below it are kept by `filter_objects_brightness`.
    """
    whole_seconds = int(get_seconds_from_midnight(local_datetime))
    return magnitude_mapping[whole_seconds]
198
199
[docs]
def fill_frame_background(colour: RGBTuple, frame_matrix: FloatArray) -> FloatArray:
    """Return a new frame-shaped array painted entirely in one colour.

    Parameters
    ----------
    colour : RGBTuple
        One value per colour channel.
    frame_matrix : FloatArray
        Channel-first array of shape (3, X, Y). Only its shape and dtype
        matter; its values are ignored and it is not modified.

    Returns
    -------
    FloatArray
        New array of shape (3, X, Y) in which channel ``c`` equals
        ``colour[c]`` at every pixel.
    """
    # give the colour shape (3, 1, 1) so it broadcasts over both pixel axes
    channel_values = np.asarray(colour)[:, np.newaxis, np.newaxis]
    return np.ones_like(frame_matrix) * channel_values
217
218
[docs]
def filter_objects_brightness(
    maximum_magnitude: float, objects_table: QTable
) -> QTable:
    """Keep only the objects bright enough to be seen.

    Lower magnitudes are brighter, so an object is kept when its magnitude is
    less than or equal to `maximum_magnitude`.

    Parameters
    ----------
    maximum_magnitude : float
        Faintest (inclusive) magnitude that is kept.
    objects_table : astropy.table.QTable
        Table of objects with a "magnitude" column.

    Returns
    -------
    astropy.table.QTable
        The rows of `objects_table` that pass the cut.
    """
    is_visible = objects_table["magnitude"] <= maximum_magnitude
    visible_objects = objects_table[is_visible]
    return visible_objects
238
239
[docs]
def filter_objects_fov(
    radec: SkyCoord,
    fov: u.Quantity["angle"],  # type: ignore[type-arg,name-defined]
    objects_table: QTable,
) -> QTable:
    """Keep only the objects close enough to the pointing to appear in frame.

    Parameters
    ----------
    radec : astropy.coordinates.SkyCoord
        Pointing (centre) of the observation.
    fov : astropy.units.Quantity[angle]
        Field of view, i.e. twice the visible radius.
    objects_table : astropy.table.QTable
        Table of objects with a "skycoord" column.

    Returns
    -------
    astropy.table.QTable
        Rows whose angular separation from `radec` is within half the field
        of view plus a 1% margin.
    """
    separations = objects_table["skycoord"].separation(radec)

    # Objects just outside the nominal radius can still spill light into the
    # frame, so the cut-off radius is padded by 1%.
    half_fov = fov / 2
    cutoff = half_fov * 1.01

    return objects_table[separations <= cutoff]
269
270
[docs]
def magnitude_to_flux(magnitude: ArrayLike) -> ArrayLike:
    """Convert apparent magnitude to flux relative to a magnitude-0 source.

    Computes ``10 ** (-magnitude / 2.5)``. A magnitude of 0 gives flux 1, and
    each 5 magnitudes brighter multiplies the flux by 100.

    Parameters
    ----------
    magnitude : numpy.typing.ArrayLike
        Apparent magnitude(s). Must support unary negation and division.

    Returns
    -------
    numpy.typing.ArrayLike
        Relative flux, of the same kind as `magnitude`.
    """
    exponent = -magnitude / 2.5  # type: ignore[operator]
    return 10**exponent
285
286
[docs]
def linear_rescale(
    data: ArrayLike, new_min: float = 0, new_max: float = 1
) -> ArrayLike:
    """Map `data` linearly so its minimum lands on `new_min` and its maximum
    on `new_max`.

    If all values are identical (zero spread), the spread is treated as 1.
    Every value then maps to `new_min`.

    Parameters
    ----------
    data : numpy.typing.ArrayLike
        Values to rescale. Must support subtraction of a scalar.
    new_min : float, optional
        Value the minimum of `data` maps to, by default 0.
    new_max : float, optional
        Value the maximum of `data` maps to, by default 1.

    Returns
    -------
    numpy.typing.ArrayLike
        Rescaled values.
    """
    lowest = np.min(data)
    highest = np.max(data)

    spread = highest - lowest
    spread = 1 if spread == 0 else spread  # avoid dividing by zero

    target_spread = new_max - new_min
    scale = target_spread / spread
    return (data - lowest) * scale + new_min
313
314
[docs]
def get_scaled_brightness(object_table: QTable) -> QTable:
    """Add a "brightness" column, in [MINIMUM_BRIGHTNESS, 1], derived from the
    "magnitude" column.

    Brightness is the base-10 log of relative flux. It is then rescaled
    linearly so the brightest object in the table gets 1 and the faintest
    gets `MINIMUM_BRIGHTNESS`. If every object has the same magnitude, all of
    them get `MINIMUM_BRIGHTNESS`.

    `object_table` is modified in place: a "brightness" column is added or
    overwritten. No other column is added or removed.

    Parameters
    ----------
    object_table : astropy.table.QTable
        Table of objects with a "magnitude" column.

    Returns
    -------
    astropy.table.QTable
        The table after `round_columns` has been applied to its "brightness"
        column.
    """
    # Flux is held in a local variable rather than a scratch "flux" column:
    # writing and then removing such a column would overwrite and delete any
    # genuine "flux" column the table already had.
    flux = magnitude_to_flux(object_table["magnitude"])
    log_flux = np.log10(flux)  # since humans see brightness log-scaled

    object_table["brightness"] = linear_rescale(
        log_flux, new_min=MINIMUM_BRIGHTNESS, new_max=1
    )

    return round_columns(object_table, ["brightness"])
340
341
[docs]
def pixel_in_frame(xy: IntArray, image_pixels: int) -> bool:
    """Report whether pixel `xy` lies inside a square frame `image_pixels` wide.

    Both of the first two coordinates must satisfy
    ``0 <= coordinate < image_pixels``.

    Parameters
    ----------
    xy : IntArray
        Pixel coordinates; only ``xy[0]`` and ``xy[1]`` are examined.
    image_pixels : int
        Side length of the frame.

    Returns
    -------
    bool
        Whether the pixel is in the frame.
    """
    inside = [0 <= coordinate < image_pixels for coordinate in (xy[0], xy[1])]
    return inside[0] and inside[1]
361
362
[docs]
def add_object_to_frame(
    object_row: Row,
    frame: FloatArray,
    area_mesh: IntArray,
    brightness_scale_mesh: FloatArray,
) -> FloatArray:
    """Blend one celestial object's light into an RGB frame, in place.

    The object's colour spreads over the pixels given by `area_mesh`, which
    holds offsets from the object's pixel position. Each of those pixels that
    lies inside the frame is replaced by a weighted average of the object
    colour and the existing colour. The weight on the object colour is
    ``brightness_scale_mesh * object brightness``.

    Parameters
    ----------
    object_row : astropy.table.Row
        Row with "x", "y", "brightness" and "rgb" entries.
    frame : FloatArray
        Channel-first RGB image, indexed as ``frame[channel, x, y]``.
    area_mesh : IntArray
        Pair of offset grids: ``area_mesh[0]`` holds x offsets and
        ``area_mesh[1]`` holds y offsets.
    brightness_scale_mesh : FloatArray
        Per-offset dimming factor, indexed like ``area_mesh[0]``.

    Returns
    -------
    FloatArray
        `frame`, modified in place, with the object added in.
    """
    frame_size = frame.shape[-1]
    centre_x = object_row["x"]
    centre_y = object_row["y"]
    object_rgb = object_row["rgb"]
    object_brightness = object_row["brightness"]

    for mesh_index in np.ndindex(area_mesh[0].shape):
        pixel_x = area_mesh[0][mesh_index] + centre_x
        pixel_y = area_mesh[1][mesh_index] + centre_y
        if not pixel_in_frame((pixel_x, pixel_y), frame_size):
            continue

        weight = brightness_scale_mesh[mesh_index] * object_brightness
        previous_rgb = frame[:, pixel_x, pixel_y]
        frame[:, pixel_x, pixel_y] = np.average(
            [object_rgb, previous_rgb], weights=[weight, 1 - weight], axis=0
        )

    return frame
405
406
[docs]
def fill_frame_objects(
    index: int,
    frame: FloatArray,
    objects_table: QTable,
    image_settings: ImageSettings,
    verbose_level: int,
) -> tuple[int, FloatArray]:
    """Draw every object in `objects_table` onto one frame.

    This is a module-level (and therefore pickle-able) function, so it can be
    dispatched through a multiprocessing pool. The frame number is returned
    with the image so the caller can put the result back in the right slot.

    `objects_table` is modified in place: integer "x" and "y" pixel columns
    are added for this frame's WCS.

    Parameters
    ----------
    index : int
        The frame number. It also selects the WCS from
        `image_settings.wcs_objects`.
    frame : FloatArray
        RGB image.
    objects_table : astropy.table.QTable
        Table of objects to add, with "ra" and "dec" in degrees.
    image_settings : ImageSettings
        Configuration.
    verbose_level : int
        How much detail to print.

    Returns
    -------
    tuple[int, FloatArray]
        Frame number and updated image.
    """
    # Project sky positions onto this frame's pixel grid. to_pixel returns
    # (x, y); flipud swaps them so that the table's "x" indexes the first
    # spatial axis of `frame` and "y" indexes the second.
    objects_table["skycoord"] = SkyCoord(
        ra=objects_table["ra"], dec=objects_table["dec"], unit="deg"
    )
    pixel_xy = objects_table["skycoord"].to_pixel(image_settings.wcs_objects[index])
    swapped = np.flipud(np.round(pixel_xy)).astype(int)
    objects_table["x"], objects_table["y"] = swapped[0], swapped[1]
    objects_table.remove_column("skycoord")

    area_mesh = image_settings.area_mesh
    brightness_scale_mesh = image_settings.brightness_scale_mesh
    for object_row in objects_table:
        frame = add_object_to_frame(
            object_row, frame, area_mesh, brightness_scale_mesh
        )

    if verbose_level > 1:
        print(f"Added {len(objects_table)} objects to image {index}.")

    return index, frame
457
458
[docs]
def prepare_object_table(
    image_settings: ImageSettings,
    star_table: QTable,
    planet_tables: list[QTable],
    frame: int,
) -> QTable:
    """Converts the star and planet tables into a single combined unit for a
    given frame.

    Steps:

    1. Stack the stars with this frame's planets.
    2. Drop objects fainter than the limiting magnitude for the frame's local
       time.
    3. Drop objects outside the field of view.
    4. If anything is left, convert "ra"/"dec" to plain degree values, add
       "brightness" and "rgb" columns, and drop the "id", "magnitude",
       "spectral_type" and "skycoord" columns.

    Parameters
    ----------
    image_settings : ImageSettings
        Configuration.
    star_table : astropy.table.QTable
        Star table.
    planet_tables : list[astropy.table.QTable]
        List of planet tables, indexed by frame number.
    frame : int
        Frame number to generate for.

    Returns
    -------
    astropy.table.QTable
        Combined table. If no object survives the filters, the empty filtered
        table is returned as-is, still holding its original columns plus
        "skycoord".
    """
    object_table = vstack([star_table, planet_tables[frame]])

    # filter by magnitude
    current_maximum_magnitude = get_timed_magnitude(
        image_settings.magnitude_mapping, image_settings.local_datetimes[frame]
    )
    object_table = filter_objects_brightness(current_maximum_magnitude, object_table)

    # filter by FOV
    object_table["skycoord"] = SkyCoord(ra=object_table["ra"], dec=object_table["dec"])
    object_table = filter_objects_fov(
        image_settings.observation_radec[frame],
        image_settings.field_of_view,
        object_table,
    )

    if len(object_table) == 0:
        return object_table

    # Make pickleable: store plain degree values instead of Quantity columns.
    # to_value returns an ndarray directly. `.to(u.deg).data` instead handed
    # the table the raw ndarray buffer (a memoryview).
    object_table["ra"] = object_table["ra"].to_value(u.deg)
    object_table["dec"] = object_table["dec"].to_value(u.deg)

    object_table = get_scaled_brightness(object_table)

    # colour is looked up from the object's spectral type
    object_table["rgb"] = [
        image_settings.object_colours[stype] for stype in object_table["spectral_type"]
    ]

    object_table.remove_columns(["id", "magnitude", "spectral_type", "skycoord"])

    return object_table