From 0d41cb183679635f4ebe299a4009da1927ffd050 Mon Sep 17 00:00:00 2001 From: Richard Ward Date: Thu, 17 Jul 2025 11:46:05 +0100 Subject: [PATCH] Add Primitive base class and refactor primitives --- parametric_cad/__init__.py | 2 + parametric_cad/primitives/base.py | 34 +++++++++++ parametric_cad/primitives/box.py | 18 +++--- parametric_cad/primitives/cylinder.py | 33 ++++------- parametric_cad/primitives/gear.py | 81 +++++++++++++++++++-------- parametric_cad/primitives/sphere.py | 22 ++++---- parametric_cad/primitives/sprocket.py | 46 +++++++++------ tests/test_primitives.py | 9 +++ 8 files changed, 159 insertions(+), 86 deletions(-) create mode 100644 parametric_cad/primitives/base.py diff --git a/parametric_cad/__init__.py b/parametric_cad/__init__.py index bf20275..f931f8d 100644 --- a/parametric_cad/__init__.py +++ b/parametric_cad/__init__.py @@ -2,6 +2,7 @@ from .core import tm, safe_difference, combine from .geometry import sg, Polygon, Point, box +from .primitives.base import Primitive from .primitives.box import Box from .primitives.cylinder import Cylinder from .primitives.gear import SpurGear @@ -15,6 +16,7 @@ __all__ = [ "sg", "safe_difference", "combine", + "Primitive", "Box", "Cylinder", "Sphere", diff --git a/parametric_cad/primitives/base.py b/parametric_cad/primitives/base.py new file mode 100644 index 0000000..bd9957d --- /dev/null +++ b/parametric_cad/primitives/base.py @@ -0,0 +1,34 @@ +from typing import Optional, Sequence + +from parametric_cad.core import tm + + +class Primitive: + """Base class for simple parametric primitives.""" + + def __init__(self) -> None: + self._position = (0.0, 0.0, 0.0) + self._rotation: Optional[tuple[Sequence[float], float]] = None + + def at(self, x: float, y: float, z: float): + """Translate the primitive to ``(x, y, z)``.""" + self._position = (x, y, z) + return self + + def rotate(self, axis: Sequence[float], angle: float): + """Rotate the primitive around ``axis`` by ``angle`` radians.""" + self._rotation = (axis, angle) + return self + + def _create_mesh(self) -> tm.Trimesh: + """Return the untransformed mesh for this primitive.""" + raise NotImplementedError + + def mesh(self) -> tm.Trimesh: + mesh = self._create_mesh() + if self._rotation is not None: + axis, angle = self._rotation + rot = tm.transformations.rotation_matrix(angle, axis) + mesh.apply_transform(rot) + mesh.apply_translation(self._position) + return mesh diff --git a/parametric_cad/primitives/box.py b/parametric_cad/primitives/box.py index 641cd5e..7188612 100644 --- a/parametric_cad/primitives/box.py +++ b/parametric_cad/primitives/box.py @@ -1,17 +1,13 @@ from parametric_cad.core import tm +from .base import Primitive -class Box: - def __init__(self, width, depth, height): + +class Box(Primitive): + def __init__(self, width: float, depth: float, height: float) -> None: + super().__init__() self.width = width self.depth = depth self.height = height - self._position = (0, 0, 0) - def at(self, x, y, z): - self._position = (x, y, z) - return self - - def mesh(self): - box = tm.creation.box(extents=(self.width, self.depth, self.height)) - box.apply_translation(self._position) - return box + def _create_mesh(self) -> tm.Trimesh: + return tm.creation.box(extents=(self.width, self.depth, self.height)) diff --git a/parametric_cad/primitives/cylinder.py b/parametric_cad/primitives/cylinder.py index 26bf960..f0850a0 100644 --- a/parametric_cad/primitives/cylinder.py +++ b/parametric_cad/primitives/cylinder.py @@ -1,32 +1,19 @@ -from parametric_cad.core import tm -from typing import Sequence, Optional +from typing import Sequence -class Cylinder: - def __init__(self, radius: float, height: float, sections: int = 32): +from parametric_cad.core import tm +from .base import Primitive + + +class Cylinder(Primitive): + def __init__(self, radius: float, height: float, sections: int = 32) -> None: + super().__init__() self.radius = radius self.height = height self.sections = sections - self._position = (0.0, 0.0, 0.0) - self._rotation: Optional[Sequence[float]] = None - def at(self, x: float, y: float, z: float) -> "Cylinder": - self._position = (x, y, z) - return self - - def rotate(self, axis: Sequence[float], angle: float) -> "Cylinder": - """Rotate the cylinder around ``axis`` by ``angle`` radians.""" - self._rotation = (axis, angle) - return self - - def mesh(self) -> tm.Trimesh: - cyl = tm.creation.cylinder( + def _create_mesh(self) -> tm.Trimesh: + return tm.creation.cylinder( radius=self.radius, height=self.height, sections=self.sections, ) - if self._rotation is not None: - axis, angle = self._rotation - rot = tm.transformations.rotation_matrix(angle, axis) - cyl.apply_transform(rot) - cyl.apply_translation(self._position) - return cyl diff --git a/parametric_cad/primitives/gear.py b/parametric_cad/primitives/gear.py index 7ffa762..97fe0a3 100644 --- a/parametric_cad/primitives/gear.py +++ b/parametric_cad/primitives/gear.py @@ -1,15 +1,33 @@ -import numpy as np -from parametric_cad.core import tm, safe_difference -from parametric_cad.geometry import Polygon -from math import pi, sin, cos, tan import logging +from math import cos, pi, sin, tan +from typing import List + +import numpy as np + +from parametric_cad.core import safe_difference, tm +from parametric_cad.geometry import Polygon +from .base import Primitive # Set up logging to file -logging.basicConfig(filename='gear_debug.log', level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig( + filename="gear_debug.log", + level=logging.DEBUG, + format="%(asctime)s - %(levelname)s - %(message)s", +) -class SpurGear: - def __init__(self, module, teeth, width=5.0, bore_diameter=5.0, hole_count=0, hole_diameter=2.0, hole_radius=None): + +class SpurGear(Primitive): + def __init__( + self, + module: float, + teeth: int, + width: float = 5.0, + bore_diameter: float = 5.0, + hole_count: int = 0, + hole_diameter: float = 2.0, + hole_radius: float | None = None, + ) -> None: + super().__init__() self.module = module self.teeth = teeth self.width = width @@ -17,32 +35,39 @@ class SpurGear: self.hole_count = hole_count self.hole_diameter = hole_diameter self.hole_radius = hole_radius or (self.pitch_diameter / 2 + module * 1.5) - logging.debug(f"Initialized SpurGear: module={module}, teeth={teeth}, width={width}, bore={bore_diameter}, holes={hole_count}") + logging.debug( + "Initialized SpurGear: module=%s, teeth=%s, width=%s, bore=%s, holes=%s", + module, + teeth, + width, + bore_diameter, + hole_count, + ) @property - def pitch_diameter(self): + def pitch_diameter(self) -> float: return self.module * self.teeth @property - def base_diameter(self): + def base_diameter(self) -> float: pressure_angle = 20 * pi / 180 return self.pitch_diameter * cos(pressure_angle) @property - def addendum(self): + def addendum(self) -> float: return self.module @property - def dedendum(self): + def dedendum(self) -> float: return 1.25 * self.module - def involute_profile(self, base_radius, outer_radius, steps=10): + def involute_profile(self, base_radius: float, outer_radius: float, steps: int = 10) -> np.ndarray: theta = np.linspace(0, np.arccos(base_radius / outer_radius), steps) x = base_radius * (np.cos(theta) + theta * np.tan(theta)) y = base_radius * (np.sin(theta) - theta * np.tan(theta)) return np.vstack((x, y)).T - def create_tooth(self): + def create_tooth(self) -> np.ndarray: pitch_radius = self.pitch_diameter / 2 base_radius = self.base_diameter / 2 outer_radius = pitch_radius + self.addendum @@ -52,7 +77,7 @@ class SpurGear: mirrored = np.copy(involute) mirrored[:, 1] *= -1 - arc = [] + arc: List[List[float]] = [] arc_steps = 10 start_angle = np.arctan2(mirrored[-1, 1], mirrored[-1, 0]) end_angle = -start_angle @@ -65,18 +90,18 @@ class SpurGear: profile = np.vstack([involute, arc, mirrored[::-1]]) if not np.allclose(profile[0], profile[-1]): profile = np.vstack([profile, profile[0]]) - logging.debug(f"Created tooth profile with {len(profile)} points") + logging.debug("Created tooth profile with %d points", len(profile)) return profile - def mesh(self): + def _create_mesh(self) -> tm.Trimesh: tooth_profile = self.create_tooth() polygon = Polygon(tooth_profile) if not polygon.is_valid: polygon = polygon.buffer(0) logging.warning("Tooth polygon was invalid, repaired with buffer") - tooth_mesh = tm.creation.extrude_polygon(polygon, self.width, engine='triangle') - logging.debug(f"Extruded tooth mesh with {len(tooth_mesh.vertices)} vertices") + tooth_mesh = tm.creation.extrude_polygon(polygon, self.width, engine="triangle") + logging.debug("Extruded tooth mesh with %d vertices", len(tooth_mesh.vertices)) all_teeth = [] for i in range(self.teeth): @@ -84,15 +109,19 @@ class SpurGear: rot = tm.transformations.rotation_matrix(angle, [0, 0, 1]) rotated_tooth = tooth_mesh.copy().apply_transform(rot) all_teeth.append(rotated_tooth) - logging.debug(f"Added tooth {i+1}/{self.teeth}") + logging.debug("Added tooth %d/%d", i + 1, self.teeth) gear_body = tm.util.concatenate(all_teeth) - logging.debug(f"Combined {self.teeth} teeth into gear body with {len(gear_body.vertices)} vertices") + logging.debug( + "Combined %d teeth into gear body with %d vertices", + self.teeth, + len(gear_body.vertices), + ) bore = tm.creation.cylinder(radius=self.bore_diameter / 2, height=self.width + 0.1) bore.apply_translation([0, 0, self.width / 2]) gear = safe_difference(gear_body, bore) - logging.debug(f"Subtracted bore, resulting mesh has {len(gear.vertices)} vertices") + logging.debug("Subtracted bore, resulting mesh has %d vertices", len(gear.vertices)) if self.hole_count > 0: hole_cylinders = [] @@ -106,7 +135,11 @@ class SpurGear: hole = hole.convex_hull hole_cylinders.append(hole) gear = safe_difference(gear, hole_cylinders) - logging.debug(f"Subtracted {self.hole_count} holes, resulting mesh has {len(gear.vertices)} vertices") + logging.debug( + "Subtracted %d holes, resulting mesh has %d vertices", + self.hole_count, + len(gear.vertices), + ) if not gear.is_watertight: logging.warning("Final gear mesh is not watertight") diff --git a/parametric_cad/primitives/sphere.py b/parametric_cad/primitives/sphere.py index 672d500..f0586d5 100644 --- a/parametric_cad/primitives/sphere.py +++ b/parametric_cad/primitives/sphere.py @@ -1,17 +1,15 @@ from parametric_cad.core import tm +from .base import Primitive -class Sphere: - def __init__(self, radius, subdivisions=3): + +class Sphere(Primitive): + def __init__(self, radius: float, subdivisions: int = 3) -> None: + super().__init__() self.radius = radius self.subdivisions = subdivisions - self._position = (0, 0, 0) - def at(self, x, y, z): - self._position = (x, y, z) - return self - - def mesh(self): - sph = tm.creation.icosphere(subdivisions=self.subdivisions, - radius=self.radius) - sph.apply_translation(self._position) - return sph + def _create_mesh(self) -> tm.Trimesh: + return tm.creation.icosphere( + subdivisions=self.subdivisions, + radius=self.radius, + ) diff --git a/parametric_cad/primitives/sprocket.py b/parametric_cad/primitives/sprocket.py index cc49826..eb9af61 100644 --- a/parametric_cad/primitives/sprocket.py +++ b/parametric_cad/primitives/sprocket.py @@ -1,11 +1,22 @@ -from math import cos, sin, pi -from parametric_cad.core import tm, safe_difference +from math import cos, pi, sin -class ChainSprocket: +from parametric_cad.core import safe_difference, tm +from .base import Primitive + + +class ChainSprocket(Primitive): """Simple chain sprocket for roller chain.""" - def __init__(self, pitch=12.7, roller_diameter=7.75, teeth=16, - thickness=5.0, bore_diameter=10.0, clearance=0.5): + def __init__( + self, + pitch: float = 12.7, + roller_diameter: float = 7.75, + teeth: int = 16, + thickness: float = 5.0, + bore_diameter: float = 10.0, + clearance: float = 0.5, + ) -> None: + super().__init__() self.pitch = float(pitch) self.roller_diameter = float(roller_diameter) self.teeth = int(teeth) @@ -14,23 +25,24 @@ class ChainSprocket: self.clearance = float(clearance) @property - def pitch_radius(self): + def pitch_radius(self) -> float: return self.pitch / (2 * sin(pi / self.teeth)) @property - def pitch_diameter(self): + def pitch_diameter(self) -> float: return self.pitch_radius * 2 - def mesh(self): + def _create_mesh(self) -> tm.Trimesh: # Base disc sized so pockets can be subtracted outer_radius = self.pitch_radius + self.roller_diameter / 2 + self.clearance - disc = tm.creation.cylinder(radius=outer_radius, height=self.thickness, - sections=self.teeth * 4) - + disc = tm.creation.cylinder( + radius=outer_radius, + height=self.thickness, + sections=self.teeth * 4, + ) disc.apply_translation([0, 0, self.thickness / 2]) - bore = tm.creation.cylinder(radius=self.bore_diameter / 2, - height=self.thickness + 0.1) + bore = tm.creation.cylinder(radius=self.bore_diameter / 2, height=self.thickness + 0.1) bore.apply_translation([0, 0, self.thickness / 2]) sprocket = safe_difference(disc, bore) @@ -40,9 +52,11 @@ class ChainSprocket: angle = 2 * pi * i / self.teeth x = cos(angle) * self.pitch_radius y = sin(angle) * self.pitch_radius - pocket = tm.creation.cylinder(radius=pocket_radius, - height=self.thickness + 0.1, - sections=16) + pocket = tm.creation.cylinder( + radius=pocket_radius, + height=self.thickness + 0.1, + sections=16, + ) pocket.apply_translation([x, y, self.thickness / 2]) pockets.append(pocket) diff --git a/tests/test_primitives.py b/tests/test_primitives.py index c15bb7c..95e75bb 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -3,6 +3,7 @@ import numpy as np from parametric_cad.core import tm, safe_difference, combine from math import cos, sin, pi +from parametric_cad.primitives.base import Primitive from parametric_cad.primitives.box import Box from parametric_cad.primitives.gear import SpurGear from parametric_cad.primitives.cylinder import Cylinder @@ -60,3 +61,11 @@ def test_chain_sprocket_properties_and_mesh(): mesh = sprocket.mesh() assert isinstance(mesh, tm.Trimesh) assert mesh.is_watertight + + +def test_primitive_inheritance(): + assert isinstance(Box(1, 1, 1), Primitive) + assert isinstance(Cylinder(1, 1), Primitive) + assert isinstance(Sphere(1), Primitive) + assert isinstance(SpurGear(module=1.0, teeth=8), Primitive) + assert isinstance(ChainSprocket(), Primitive)