mirror of
https://github.com/johndoe6345789/metabuilder.git
synced 2026-04-25 06:14:59 +00:00
136 lines
4.8 KiB
Mojo
Executable File
136 lines
4.8 KiB
Mojo
Executable File
# ===----------------------------------------------------------------------=== #
|
|
# Copyright (c) 2025, Modular Inc. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License v2.0 with LLVM Exceptions:
|
|
# https://llvm.org/LICENSE.txt
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ===----------------------------------------------------------------------=== #
|
|
|
|
from math import ceildiv
|
|
from os import abort
|
|
from sys import has_accelerator
|
|
|
|
from complex import ComplexSIMD, ComplexScalar
|
|
from gpu import global_idx
|
|
from gpu.host import DeviceContext
|
|
from layout import Layout, LayoutTensor
|
|
from python import PythonObject
|
|
from python.bindings import PythonModuleBuilder
|
|
|
|
# Element types used for the complex-plane coordinates and the iteration counts.
comptime float_dtype = DType.float32
comptime int_dtype = DType.int32

# Size of the rendered grid, in character cells.
comptime GRID_WIDTH = 60
comptime GRID_HEIGHT = 25

# Row-major (height x width) layout shared by the device and host tensors.
comptime layout = Layout.row_major(GRID_HEIGHT, GRID_WIDTH)

# Region of the complex plane that the grid is mapped onto.
comptime MIN_X: Scalar[float_dtype] = -2.0
comptime MAX_X: Scalar[float_dtype] = 0.7
comptime MIN_Y: Scalar[float_dtype] = -1.12
comptime MAX_Y: Scalar[float_dtype] = 1.12
|
|
|
|
|
|
# Python loads this Mojo module through the exported initializer below.
@export
fn PyInit_mandelbrot_mojo() -> PythonObject:
    """Build the `mandelbrot_mojo` Python module and register its functions.

    Returns:
        The finalized Python module object. If construction fails, the
        process is aborted with a diagnostic message instead of returning.
    """
    try:
        # The builder is given the same name as this Mojo module.
        var builder = PythonModuleBuilder("mandelbrot_mojo")
        # Expose the GPU entry point to Python under its own name.
        builder.def_function[run_mandelbrot]("run_mandelbrot")
        return builder.finalize()
    except err:
        abort(String("failed to create Python module: ", err))
|
|
|
|
|
|
fn run_mandelbrot(iterations: PythonObject) raises -> PythonObject:
    """Compute the Mandelbrot grid on the GPU and return it as ASCII art.

    This is the function registered with the Python module and called from
    Python.

    Args:
        iterations: Python integer giving the iteration budget per grid cell.

    Returns:
        The string produced by `draw_mandelbrot` for the computed grid.

    Raises:
        If converting `iterations` or any device operation fails.
    """
    constrained[has_accelerator(), "This example requires a supported GPU"]()

    # Launch geometry: square thread blocks, with the block counts rounded up
    # so that partially filled blocks at the grid edges are still launched.
    comptime THREADS_PER_EDGE = 16
    comptime BLOCKS_X = ceildiv(GRID_WIDTH, THREADS_PER_EDGE)
    comptime BLOCKS_Y = ceildiv(GRID_HEIGHT, THREADS_PER_EDGE)

    # Context for the attached GPU.
    var ctx = DeviceContext()

    # Device-side storage for the result, viewed as a 2D tensor.
    var device_buffer = ctx.enqueue_create_buffer[int_dtype](
        comptime (layout.size())
    )
    var device_tensor = LayoutTensor[int_dtype, layout](device_buffer)

    # Run the kernel over a 2D grid of thread blocks and wait for it to finish.
    var max_iters = Int32(py=iterations)
    ctx.enqueue_function[mandelbrot, mandelbrot](
        device_tensor,
        max_iters,
        grid_dim=(BLOCKS_X, BLOCKS_Y),
        block_dim=(THREADS_PER_EDGE, THREADS_PER_EDGE),
    )
    ctx.synchronize()

    # Map the device buffer to the host so the results can be read, then hand
    # the rendered ASCII art back to Python.
    with device_buffer.map_to_host() as mapped:
        var host_view = LayoutTensor[int_dtype, layout](mapped)
        return draw_mandelbrot(host_view, Int(py=iterations))
|
|
|
|
|
|
fn mandelbrot(
    tensor: LayoutTensor[int_dtype, layout, MutAnyOrigin], iterations: Int32
):
    """Compute the iterations-to-escape for one grid cell of the Mandelbrot set.

    Each GPU thread handles the cell at (`global_idx.y`, `global_idx.x`) and
    stores its iteration count into `tensor`. Threads whose position lies
    outside the GRID_HEIGHT x GRID_WIDTH grid return without writing.

    Args:
        tensor: Output tensor that receives the per-cell iteration counts.
        iterations: Maximum number of iterations to run for each cell.
    """
    # Obtain the position in the grid from the X, Y thread locations.
    var row = global_idx.y
    var col = global_idx.x

    # The launch rounds the block counts up, so it can contain more threads
    # than grid cells. Without this guard the surplus threads would store
    # through indices outside the tensor: extra columns alias the start of the
    # next row in the row-major buffer, and extra rows land past its end.
    if Int(row) >= GRID_HEIGHT or Int(col) >= GRID_WIDTH:
        return

    comptime SCALE_X = (MAX_X - MIN_X) / GRID_WIDTH
    comptime SCALE_Y = (MAX_Y - MIN_Y) / GRID_HEIGHT

    # Calculate the complex C corresponding to that grid location.
    var cx = MIN_X + Float32(col) * SCALE_X
    var cy = MIN_Y + Float32(row) * SCALE_Y
    var c = ComplexScalar[float_dtype](cx, cy)

    # Perform the Mandelbrot iteration loop calculation.
    var z = ComplexScalar[float_dtype](0, 0)
    var iters = Scalar[int_dtype](0)

    var in_set_mask = Scalar[DType.bool](True)
    for _ in range(iterations):
        # Stop as soon as the point has escaped (|z|^2 > 4).
        if not any(in_set_mask):
            break
        in_set_mask = z.squared_norm().le(4)
        # Count this step only while the point is still within the bound.
        iters = in_set_mask.select(iters + 1, iters)
        z = z.squared_add(c)

    # Write out the resulting iterations to escape.
    tensor[row, col] = iters
|
|
|
|
|
|
def draw_mandelbrot(
    tensor: LayoutTensor[int_dtype, layout], iterations: Int32
) -> String:
    """Render the iteration counts in `tensor` as ASCII art.

    Args:
        tensor: Per-cell iteration counts for the GRID_HEIGHT x GRID_WIDTH grid.
        iterations: The iteration budget; cells whose count reached it are
            drawn as a space.

    Returns:
        One line of GRID_WIDTH characters per grid row, each terminated by a
        newline.
    """
    comptime palette = StringSlice("....,c8M@jawrpogOQEPGJ")
    var art = String()
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            var count = tensor[y, x]
            if count >= iterations:
                # The cell used the whole budget: leave it blank.
                art += " "
                continue
            # Otherwise pick a glyph by cycling through the palette.
            art += palette[byte = Int(count % len(palette))]
        art += "\n"
    return art
|