# Copyright (C) 2024 - 2026 ANSYS, Inc. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
.. _ref_wf_ml_01_ion_trap_modelling:

Q3D - Simplified ion trap modeling
##################################

Generate a 3D BEM model of a 3-rail electrode Ion Trap to identify the nodal point and
then optimize the grating coupler design with PyAnsys (PyAEDT and PyLumerical).

Problem description
-------------------

Surface electrodes adjacent to grating couplers render an integrated ion trap.
In this workflow surface electrodes are modeled using ANSYS Q3D and grating couplers using
ANSYS Lumerical. The ANSYS Q3D CG solver, based on the boundary element method, is used to
evaluate the ion trap height. Then the coordinates of the ion trap are passed to an optimization
algorithm to define the optimal two-dimensional grating coupler design, which will focus the laser
beam at the ion trap height.

The workflow is explained in detail in this article:
https://optics.ansys.com/hc/en-us/articles/20715978394131-Integrated-Ion-Traps-using-Surface-Electrodes-and-Grating-Couplers

The workflow consists of these steps:

- Set up the Q3D parametric model.
- Identify the electric field node point for each design point.
- Export the node coordinates for the subsequent Lumerical step.
- Launch the Lumerical scripts.

"""  # noqa: D400, D415

# Perform required imports
# ------------------------

import os
from pathlib import Path
import shutil
import tempfile
import time

from PIL import Image
from ansys.aedt.core import Q3d
from ansys.lumerical.core import FDTD

# sphinx_gallery_start_ignore
# Check if the __file__ variable is defined. If not, set it.
# This is a workaround to run the script in Sphinx-Gallery.
# The fallback is a ``Path`` object (not a string) pointing at a file name in the
# current working directory; it is only used further down to locate the folder
# that holds the companion Lumerical scripts and image.
if "__file__" not in locals():
    __file__ = Path(os.getcwd(), "wf_q3l_01_ion_trap_modelling.py")
# sphinx_gallery_end_ignore

###############################################################################
# Prepare and launch Q3D
# ----------------------
# Define constants.

# Solver configuration. The AEDT version can be overridden with the AEDT_VERSION
# environment variable.
AEDT_VERSION = os.getenv("AEDT_VERSION", "2025.2")  # Set your AEDT version here
NUM_CORES = 4
# Open the AEDT UI when it is launched, unless the ON_CI environment variable is
# set to "true" (any letter case).
NG_MODE = os.getenv("ON_CI", "false").lower() == "true"

# Names of the table exported by Q3D and of the reformatted copy written for Lumerical.
NODE_FILENAME = "NodePositionTable.tab"
LEGEND_FILENAME = "legend.txt"

# Folder containing this script; the Lumerical scripts and the image are copied from it.
PARENT_DIR_PATH = Path(__file__).parent.absolute()

# Create the temporary working directory. The AEDT project, the Lumerical scripts
# and both tables are all placed directly at its root.
temp_folder = tempfile.TemporaryDirectory(suffix=".ansys")
lumerical_script_folder = Path(temp_folder.name)
node_path, legend_path = (
    lumerical_script_folder / file_name for file_name in (NODE_FILENAME, LEGEND_FILENAME)
)

# Launch AEDT and start a Q3D design.
# The project file is created inside the temporary folder, so it is removed
# together with that folder at the end of the script.

project_name = os.path.join(temp_folder.name, "IonTrapQ3D.aedt")
q3d = Q3d(
    project=project_name,
    design="01_Q3D_IonTrap_3rails",
    version=AEDT_VERSION,
    non_graphical=NG_MODE,
    new_desktop=True,
)
# All geometry below is expressed in micrometers.
q3d.modeler.model_units = "um"

###############################################################################
# Preprocess
# ----------
# The preprocessing is performed using the following steps:
#
# 1. Define design variables.
# 2. Create design geometry.
# 3. Define excitations.
# 4. (Optional) Define mesh settings.

# Design variables, keyed by variable name. Each value is the expression string
# assigned to the variable in the Q3D design; ``w_dc`` is derived from ``div``,
# which is the variable swept later in the script.

geom_params = dict(
    div=str(73 / 41),
    w_rf="41um",
    w_dc="41um*div",
    w_cut="4um",
    metal_thickness="1um",
    offset_glass="50um",
    glass_thickness="10um",
    x_dummy="2um",
    y_dummy="300um",
    Z_length="300um",
)

# Register every entry as a design variable, in declaration order.

for variable_name, expression in geom_params.items():
    q3d[variable_name] = expression

# Create design geometry
# Every part is first drawn as a rectangle in the XY plane at z = 0 (the trap
# cross-section) and is turned into a solid by the sweep at the end of this block.
# Only the left-hand RF rail, gap and dummy box are drawn here; their right-hand
# counterparts are produced later by mirroring.

# Central DC rail, centered on x = 0 and on y = 0.
dc = q3d.modeler.create_rectangle(
    orientation="XY",
    origin=["-w_dc/2", "-metal_thickness/2", "0"],
    sizes=["w_dc", "metal_thickness"],
    name="DC",
    material="aluminum",
)
# dc.color = (0, 0, 255)  # rgb

# Ground plane underneath the glass layer (negative size: drawn downwards in y).
gnd = q3d.modeler.create_rectangle(
    orientation="XY",
    origin=["-(w_dc/2+w_cut+w_rf+offset_glass)", "-(metal_thickness/2+glass_thickness)", "0"],
    sizes=["2*(w_dc/2+w_cut+w_rf+offset_glass)", "-metal_thickness"],
    name="gnd",
    material="aluminum",
)
# Left RF rail, separated from the DC rail by a gap of width w_cut.
rf = q3d.modeler.create_rectangle(
    orientation="XY",
    origin=["-(w_dc/2+w_cut+w_rf)", "-metal_thickness/2", "0"],
    sizes=["w_rf", "metal_thickness"],
    name="RF",
    material="aluminum",
)
# Glass layer directly below the rails, extending offset_glass beyond the RF rails.
sub_glass = q3d.modeler.create_rectangle(
    orientation="XY",
    origin=["-(w_dc/2+w_cut+w_rf+offset_glass)", "-metal_thickness/2", "0"],
    sizes=["2*(w_dc/2+w_cut+w_rf+offset_glass)", "-glass_thickness"],
    name="substrate_glass",
    material="glass",
)
# Vacuum object filling the gap between the left RF rail and the DC rail.
ins = q3d.modeler.create_rectangle(
    orientation="XY",
    origin=["-(w_dc/2+w_cut)", "-metal_thickness/2", "0"],
    sizes=["w_cut", "metal_thickness"],
    name="ins",
    material="vacuum",
)

# Create dummy objects for mesh and center line for postprocessing and region
# The dummy box is a vacuum object that starts on the top metal surface at x = 0
# and extends x_dummy towards negative x and y_dummy upwards.

dummy = q3d.modeler.create_rectangle(
    orientation="XY",
    origin=["0", "metal_thickness/2", "0"],
    sizes=["-x_dummy", "y_dummy"],
    name="dummy",
    material="vacuum",
)

# Extrude in z-direction
# NOTE(review): _get_model_objects() is a private modeler method (leading
# underscore); presumably it returns all model objects created so far - confirm
# and prefer a public accessor if one is available.

q3d.modeler.sweep_along_vector(
    assignment=q3d.modeler._get_model_objects(),
    sweep_vector=[0, 0, "Z_length"],
    draft_angle=0,
    draft_type="Round",
)

# Create center line for post-processing
# The line starts on the top metal surface at x = 0, runs in +y over the full
# center-line length and sits at half that length along z.

center_line_length = 300 * 1e-6  # 300 um
center_line_length_str = str(center_line_length * 1e6)  # in um
mid_center_line_length_str = str(0.5 * center_line_length * 1e6)  # in um

# Both end points share the same z coordinate.
center_line_z = f"{mid_center_line_length_str}um"
center_line_start = ["0", "metal_thickness/2", center_line_z]
center_line_end = ["0", f"metal_thickness/2+{center_line_length_str}um", center_line_z]

center_line = q3d.modeler.create_polyline(
    points=[center_line_start, center_line_end],
    name="center_line",
)

# Define excitations
# Nets are identified automatically from the geometry instead of being assigned
# one by one.

q3d.auto_identify_nets()

# Define mesh settings
# Only the initial mesh method is set here; the optional refinements are listed
# in the commented-out block that follows.

q3d.mesh.assign_initial_mesh(method="AnsoftClassic")

# For good quality results, uncomment the following mesh operations lines
#
# q3d.mesh.assign_length_mesh(assignment=center_line.id,
#                             maximum_length=1e-7,
#                             maximum_elements=None,
#                             name="center_line_0.1um")
# q3d.mesh.assign_length_mesh(assignment=dummy.name,
#                             maximum_length=2e-6,
#                             maximum_elements=1e6,
#                             name="dummy_2um")
# q3d.mesh.assign_length_mesh(assignment=ins.id,
#                             maximum_length=8e-7,
#                             inside_selection=False,
#                             maximum_elements=1e6,
#                             name="ins_0.8um")
# q3d.mesh.assign_length_mesh(assignment=[dc.id, rf.id],
#                             maximum_length=5e-6,
#                             inside_selection=False,
#                             maximum_elements=1e6,
#                             name="dc_5um")
# q3d.mesh.assign_length_mesh(assignment=gnd.id,
#                             maximum_length=1e-5,
#                             inside_selection=False,
#                             maximum_elements=1e6,
#                             name="gnd_10um")

# Duplicate structures and assignments to complete the model
# The left RF rail, dummy box and gap are copied to the other side of the DC rail.
# NOTE(review): the mirror is defined by a point at the origin and the vector
# (-1, 0, 0), presumably the mirror-plane normal; the mirrored RF copy is
# presumably the "RF_1" net referenced when the sources are edited - confirm.

q3d.modeler.duplicate_and_mirror(
    assignment=[rf.id, dummy.id, ins.id],
    origin=["0", "0", "0"],
    vector=["-1", "0", "0"],
    duplicate_assignment=True,
)

###############################################################################
# Run simulation and parametric sweep
# -----------------------------------
# Create, validate, and analyze setup

# Create the solution setup with fields saved, so that field reports can be
# created during postprocessing.
setup_name = "MySetupAuto"
setup1 = q3d.create_setup(props={"Name": setup_name, "AdaptiveFreq": "1Hz", "SaveFields": True})
# Disable the AC RL and DC solutions on the setup.
setup1.ac_rl_enabled = False
setup1.dc_enabled = False
setup1.update()
# Validate the design, then solve the nominal design point.
q3d.validate_simple()
q3d.analyze_setup(name=setup_name, use_auto_settings=False, cores=NUM_CORES)

# Create and solve the parametric sweep.
# w_rf is kept constant at 41um and w_dc is defined as 41um*div, so sweeping
# ``div`` sweeps the w_dc/w_rf ratio.

div_sweep_start = 1.4
div_sweep_stop = 2
sweep = q3d.parametrics.add(
    variable="div",
    start_point=div_sweep_start,
    end_point=div_sweep_stop,
    step=0.2,
    variation_type="LinearStep",
    name="w_dc_sweep",
)

# Add individual design points that fall outside the linear sweep range.
# A plain loop is used because the calls are made only for their side effect.
add_points = [1, 1.3]
for point in add_points:
    sweep.add_variation(sweep_variable="div", start_point=point, variation_type="SingleValue")

# Keep the fields for every variation so each one can be postprocessed.
sweep["SaveFields"] = True
sweep.analyze(cores=NUM_CORES)

###############################################################################
# Postprocess
# -----------
# Create the Ey expression in the PyAEDT advanced field calculator.
# Due to the symmetric nature of this specific geometry, the electric field
# node will be located along the center line. The electric field node is the
# point where the Ey will be zero and can be found directly by Q3D
# postprocessing features.

# Edit sources to scale the solution for the actual assigned potentials.
# Each entry maps a net name to a (magnitude, phase) pair: both RF nets are set
# to 1V, while the DC rail and the ground plane are set to 0V.

sources_cg = {
    "DC": ("0V", "0deg"),
    "gnd": ("0V", "0deg"),
    "RF": ("1V", "0deg"),
    "RF_1": ("1V", "0deg"),
}
q3d.edit_sources(sources_cg)

# Evaluate the E-field on the control line and export nodal points
# A field line with 1000 points is inserted on the "center_line" polyline, and
# re(EY) is plotted against the normalized distance along it.

line_name = "Line1"
q3d.insert_em_field_line(assignment="center_line", points=1000, name=line_name)
my_plots = q3d.post.create_report(
    expressions="re(EY)",
    primary_sweep_variable="NormalizedDistance",
    report_category="Static EM Fields",
    plot_type="Rectangular Plot",
    context=line_name,
    plot_name="my_plot",
)
# Restrict the x axis to normalized distances between 0.01 and 1.
my_plots.edit_x_axis_scaling(min_scale="0.01", max_scale="1")
# Show one trace per value of the swept variable ``div``.
my_plots.update_trace_in_report(
    my_plots.get_solution_data().expressions, variations={"div": ["All"]}, context=line_name
)

# Identify the zero point for each trace
# A horizontal marker is added at y = 0, plus an "XAtYVal" trace characteristic
# with argument 0 evaluated over the 0.01 to 1.0 range.

my_plots.add_cartesian_y_marker("0")
my_plots.add_trace_characteristics(
    "XAtYVal", arguments=["0"], solution_range=["Full", "0.01", "1.0"]
)

# Export the points at which Ey=0 to the node table file in the temporary folder

my_plots.edit_general_settings(use_scientific_notation=True)
my_plots.export_table_to_file(my_plots.plot_name, str(node_path), "Legend")

###############################################################################
# Prepare and run Lumerical simulation
# ------------------------------------
# Edit the file outputted by Q3D to be read in by Lumerical

# Read the table exported by Q3D.
with open(node_path, "r", encoding="utf-8") as f:
    lines = f.readlines()

# Fail with an explicit message instead of a bare IndexError on the header lookup.
if not lines:
    raise ValueError(f"The exported node table is empty: {node_path}")

# Keep the header line unchanged, then write the first two tab-separated columns
# of every data row on two consecutive lines.
new_line = [lines[0]]
for line in lines[1:]:
    # Blank lines (for example a trailing newline) carry no data.
    if not line.strip():
        continue
    columns = line.split("\t")
    new_line.append(columns[0])
    new_line.append("\n" + columns[1].lstrip())

# Write the reformatted table next to the Lumerical scripts.
with open(legend_path, "w", encoding="utf-8") as f:
    f.writelines(new_line)

# Copy Lumerical scripts and illustration to the local folder
# The files are copied from the folder containing this script into the temporary
# working folder; ``shutil.copy`` returns the path of each copy.

gc_farfield_path = shutil.copy(PARENT_DIR_PATH / "GC_farfield.lsf", lumerical_script_folder)
gc_opt_path = shutil.copy(PARENT_DIR_PATH / "GC_Opt.lsf", lumerical_script_folder)
read_data_path = shutil.copy(PARENT_DIR_PATH / "Readata.lsf", lumerical_script_folder)
img_path = shutil.copy(PARENT_DIR_PATH / "img_001.jpg", lumerical_script_folder)

# Start the Lumerical process
# A first FDTD session is opened on the copied GC_Opt.lsf script.
# NOTE(review): presumably opening an .lsf file executes it - confirm against
# the PyLumerical documentation.

gc_0 = FDTD(gc_opt_path)

# Run the first script: build geometry and run optimization
# A second FDTD session is opened on the copied Readata.lsf script; the variable
# "T5" is then read back from that session and reported as the nodal point height.

gc_1 = FDTD(read_data_path)
print(
    "Optimize for the nodal point located",
    str(gc_1.getv("T5")),
    "um, above the linearly apodized grating coupler",
)

# Run the optimized design
# NOTE(review): "Testsim_Intensity_best_solution" is expected in the temporary
# folder; presumably it is written by the optimization scripts - confirm.
# The project is saved under a new name before the simulation is run.

gc_2 = FDTD(str(lumerical_script_folder / "Testsim_Intensity_best_solution"))
gc_2.save(str(lumerical_script_folder / "GC_farfields_calc"))
gc_2.run()

# Run the second script for calculating plots
# The far-field script is evaluated in the session holding the solved design, and
# the variables it defines are read back with ``getv``.
# Lengths are multiplied by 1e6 to print them in um and by 1e9 to print them in
# nm; the relative error is multiplied by 100 to print it as a percentage.
# NOTE(review): this assumes the script stores lengths in meters - confirm.

gc_2.feval(gc_farfield_path)
print(f"Target focal distance of output laser beam: {gc_2.getv('Mselect') * 1000000} (um)")
print(f"Actual focal distance for the optimised geometry: {gc_2.getv('Mactual') * 1000000} (um)")
print(f"Relative error: {gc_2.getv('RelVal') * 100}%")
print(f"FWHM of vertical direction at focus: {gc_2.getv('FWHM_X') * 1000000} (um)")
print(f"FWHM of horizontal direction at focus {gc_2.getv('FWHM_Y') * 1000000} (um)")
print(f"Substrate material : {gc_2.getv('Material')}")

print(f"Waveguide etch depth: {gc_2.getv('GC_etch') * 1000000000} (nm)")
print(f"Grating period (P): {gc_2.getv('GC_period') * 1000000000} (nm)")
print(f"Grating minimum duty cycle: {gc_2.getv('GC_DCmin')}")

# Display grating schema image


def in_ipython():
    """Report whether the script is running inside an IPython session.

    Returns
    -------
    bool
        ``True`` when IPython can be imported and ``get_ipython()`` returns a
        shell object. ``False`` when no shell is active or when an
        ``ImportError`` is raised while importing or querying IPython.
    """
    try:
        from IPython import get_ipython

        shell = get_ipython()
    except ImportError:
        return False
    return shell is not None


# Open the grating schema image from the folder containing this script.
schema_img = Image.open(PARENT_DIR_PATH / "img_001.jpg")

# The try/finally guarantees that the image file is closed even if displaying
# it raises an exception.
try:
    if in_ipython():
        # Show image inside IPython / Jupyter
        from IPython.display import display

        display(schema_img)
    else:
        # Show image using default image viewer
        schema_img.show()
finally:
    schema_img.close()

###############################################################################
# Exit the solver
# ---------------
# Close FDTD projects and release AEDT
#

# Close the three Lumerical FDTD sessions opened above.
gc_0.close()
gc_1.close()
gc_2.close()
# Save the Q3D project before releasing AEDT.
q3d.save_project()
q3d.desktop_class.release_desktop()

# Wait three seconds to allow AEDT to shut down before cleaning the temporary directory
time.sleep(3)

# Clean up the temporary folder
# This deletes the AEDT project, the copied Lumerical scripts and the exported tables.
temp_folder.cleanup()
