Files
2026-03-09 22:30:41 +00:00

136 lines
4.8 KiB
Mojo
Executable File

# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
from math import ceildiv
from os import abort
from sys import has_accelerator
from complex import ComplexSIMD, ComplexScalar
from gpu import global_idx
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor
from python import PythonObject
from python.bindings import PythonModuleBuilder
comptime GRID_WIDTH = 60
comptime GRID_HEIGHT = 25
comptime float_dtype = DType.float32
comptime int_dtype = DType.int32
comptime MIN_X: Scalar[float_dtype] = -2.0
comptime MAX_X: Scalar[float_dtype] = 0.7
comptime MIN_Y: Scalar[float_dtype] = -1.12
comptime MAX_Y: Scalar[float_dtype] = 1.12
comptime layout = Layout.row_major(GRID_HEIGHT, GRID_WIDTH)
# An interface for this Mojo module must be exported to Python.
@export
fn PyInit_mandelbrot_mojo() -> PythonObject:
try:
# A Python module is constructed, matching the name of this Mojo module.
var module = PythonModuleBuilder("mandelbrot_mojo")
# The functions to be exported are registered within this module.
module.def_function[run_mandelbrot]("run_mandelbrot")
return module.finalize()
except e:
abort(String("failed to create Python module: ", e))
fn run_mandelbrot(iterations: PythonObject) raises -> PythonObject:
"""The main GPU dispatch function for the Mandelbrot calculation called from Python.
"""
constrained[has_accelerator(), "This example requires a supported GPU"]()
# Get the context for the attached GPU
var ctx = DeviceContext()
# Allocate a tensor on the target device to hold the resulting set.
var dev_buf = ctx.enqueue_create_buffer[int_dtype](comptime (layout.size()))
var out_tensor = LayoutTensor[int_dtype, layout](dev_buf)
# Compute how many blocks are needed in each dimension to fully cover the grid,
# rounding up to ensure even partially filled blocks are launched.
comptime BLOCK_SIZE = 16
comptime COL_BLOCKS = ceildiv(GRID_WIDTH, BLOCK_SIZE)
comptime ROW_BLOCKS = ceildiv(GRID_HEIGHT, BLOCK_SIZE)
# Launch the Mandelbrot kernel on the GPU with a 2D grid of thread blocks.
ctx.enqueue_function[mandelbrot, mandelbrot](
out_tensor,
Int32(py=iterations),
grid_dim=(COL_BLOCKS, ROW_BLOCKS),
block_dim=(BLOCK_SIZE, BLOCK_SIZE),
)
ctx.synchronize()
# Map the output tensor data to CPU so that we can read the results.
with dev_buf.map_to_host() as host_buf:
var host_tensor = LayoutTensor[int_dtype, layout](host_buf)
# Return the ASCII art string representation to Python.
return draw_mandelbrot(host_tensor, Int(py=iterations))
fn mandelbrot(
tensor: LayoutTensor[int_dtype, layout, MutAnyOrigin], iterations: Int32
):
"""The per-element calculation of iterations to escape in the Mandelbrot set.
"""
# Obtain the position in the grid from the X, Y thread locations.
var row = global_idx.y
var col = global_idx.x
comptime SCALE_X = (MAX_X - MIN_X) / GRID_WIDTH
comptime SCALE_Y = (MAX_Y - MIN_Y) / GRID_HEIGHT
# Calculate the complex C corresponding to that grid location.
var cx = MIN_X + Float32(col) * SCALE_X
var cy = MIN_Y + Float32(row) * SCALE_Y
var c = ComplexScalar[float_dtype](cx, cy)
# Perform the Mandelbrot iteration loop calculation.
var z = ComplexScalar[float_dtype](0, 0)
var iters = Scalar[int_dtype](0)
var in_set_mask = Scalar[DType.bool](True)
for _ in range(iterations):
if not any(in_set_mask):
break
in_set_mask = z.squared_norm().le(4)
iters = in_set_mask.select(iters + 1, iters)
z = z.squared_add(c)
# Write out the resulting iterations to escape.
tensor[row, col] = iters
def draw_mandelbrot(
tensor: LayoutTensor[int_dtype, layout], iterations: Int32
) -> String:
"""A helper function to visualize the Mandelbrot set in ASCII art."""
comptime sr = StringSlice("....,c8M@jawrpogOQEPGJ")
var buffer = String()
for row in range(GRID_HEIGHT):
for col in range(GRID_WIDTH):
var v = tensor[row, col]
if v < iterations:
var idx = Int(v % len(sr))
var p = sr[byte=idx]
buffer += p
else:
buffer += " "
buffer += "\n"
return buffer