Add Primitive base class and refactor primitives

This commit is contained in:
Richard Ward
2025-07-17 11:46:05 +01:00
parent ccba30ebec
commit 0d41cb1836
8 changed files with 159 additions and 86 deletions

View File

@@ -2,6 +2,7 @@
from .core import tm, safe_difference, combine
from .geometry import sg, Polygon, Point, box
from .primitives.base import Primitive
from .primitives.box import Box
from .primitives.cylinder import Cylinder
from .primitives.gear import SpurGear
@@ -15,6 +16,7 @@ __all__ = [
"sg",
"safe_difference",
"combine",
"Primitive",
"Box",
"Cylinder",
"Sphere",

View File

@@ -0,0 +1,34 @@
from typing import Optional, Sequence
from parametric_cad.core import tm
class Primitive:
"""Base class for simple parametric primitives."""
def __init__(self) -> None:
self._position = (0.0, 0.0, 0.0)
self._rotation: Optional[tuple[Sequence[float], float]] = None
def at(self, x: float, y: float, z: float):
"""Translate the primitive to ``(x, y, z)``."""
self._position = (x, y, z)
return self
def rotate(self, axis: Sequence[float], angle: float):
"""Rotate the primitive around ``axis`` by ``angle`` radians."""
self._rotation = (axis, angle)
return self
def _create_mesh(self) -> tm.Trimesh:
"""Return the untransformed mesh for this primitive."""
raise NotImplementedError
def mesh(self) -> tm.Trimesh:
mesh = self._create_mesh()
if self._rotation is not None:
axis, angle = self._rotation
rot = tm.transformations.rotation_matrix(angle, axis)
mesh.apply_transform(rot)
mesh.apply_translation(self._position)
return mesh

View File

@@ -1,17 +1,13 @@
from parametric_cad.core import tm
from .base import Primitive
class Box:
def __init__(self, width, depth, height):
class Box(Primitive):
def __init__(self, width: float, depth: float, height: float) -> None:
super().__init__()
self.width = width
self.depth = depth
self.height = height
self._position = (0, 0, 0)
def at(self, x, y, z):
self._position = (x, y, z)
return self
def mesh(self):
box = tm.creation.box(extents=(self.width, self.depth, self.height))
box.apply_translation(self._position)
return box
def _create_mesh(self) -> tm.Trimesh:
return tm.creation.box(extents=(self.width, self.depth, self.height))

View File

@@ -1,32 +1,19 @@
from parametric_cad.core import tm
from typing import Sequence, Optional
from typing import Sequence
class Cylinder:
def __init__(self, radius: float, height: float, sections: int = 32):
from parametric_cad.core import tm
from .base import Primitive
class Cylinder(Primitive):
def __init__(self, radius: float, height: float, sections: int = 32) -> None:
super().__init__()
self.radius = radius
self.height = height
self.sections = sections
self._position = (0.0, 0.0, 0.0)
self._rotation: Optional[Sequence[float]] = None
def at(self, x: float, y: float, z: float) -> "Cylinder":
self._position = (x, y, z)
return self
def rotate(self, axis: Sequence[float], angle: float) -> "Cylinder":
"""Rotate the cylinder around ``axis`` by ``angle`` radians."""
self._rotation = (axis, angle)
return self
def mesh(self) -> tm.Trimesh:
cyl = tm.creation.cylinder(
def _create_mesh(self) -> tm.Trimesh:
return tm.creation.cylinder(
radius=self.radius,
height=self.height,
sections=self.sections,
)
if self._rotation is not None:
axis, angle = self._rotation
rot = tm.transformations.rotation_matrix(angle, axis)
cyl.apply_transform(rot)
cyl.apply_translation(self._position)
return cyl

View File

@@ -1,15 +1,33 @@
import numpy as np
from parametric_cad.core import tm, safe_difference
from parametric_cad.geometry import Polygon
from math import pi, sin, cos, tan
import logging
from math import cos, pi, sin, tan
from typing import List
import numpy as np
from parametric_cad.core import safe_difference, tm
from parametric_cad.geometry import Polygon
from .base import Primitive
# Set up logging to file
logging.basicConfig(filename='gear_debug.log', level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(
filename="gear_debug.log",
level=logging.DEBUG,
format="%(asctime)s - %(levelname)s - %(message)s",
)
class SpurGear:
def __init__(self, module, teeth, width=5.0, bore_diameter=5.0, hole_count=0, hole_diameter=2.0, hole_radius=None):
class SpurGear(Primitive):
def __init__(
self,
module: float,
teeth: int,
width: float = 5.0,
bore_diameter: float = 5.0,
hole_count: int = 0,
hole_diameter: float = 2.0,
hole_radius: float | None = None,
) -> None:
super().__init__()
self.module = module
self.teeth = teeth
self.width = width
@@ -17,32 +35,39 @@ class SpurGear:
self.hole_count = hole_count
self.hole_diameter = hole_diameter
self.hole_radius = hole_radius or (self.pitch_diameter / 2 + module * 1.5)
logging.debug(f"Initialized SpurGear: module={module}, teeth={teeth}, width={width}, bore={bore_diameter}, holes={hole_count}")
logging.debug(
"Initialized SpurGear: module=%s, teeth=%s, width=%s, bore=%s, holes=%s",
module,
teeth,
width,
bore_diameter,
hole_count,
)
@property
def pitch_diameter(self):
def pitch_diameter(self) -> float:
return self.module * self.teeth
@property
def base_diameter(self):
def base_diameter(self) -> float:
pressure_angle = 20 * pi / 180
return self.pitch_diameter * cos(pressure_angle)
@property
def addendum(self):
def addendum(self) -> float:
return self.module
@property
def dedendum(self):
def dedendum(self) -> float:
return 1.25 * self.module
def involute_profile(self, base_radius, outer_radius, steps=10):
def involute_profile(self, base_radius: float, outer_radius: float, steps: int = 10) -> np.ndarray:
theta = np.linspace(0, np.arccos(base_radius / outer_radius), steps)
x = base_radius * (np.cos(theta) + theta * np.tan(theta))
y = base_radius * (np.sin(theta) - theta * np.tan(theta))
return np.vstack((x, y)).T
def create_tooth(self):
def create_tooth(self) -> np.ndarray:
pitch_radius = self.pitch_diameter / 2
base_radius = self.base_diameter / 2
outer_radius = pitch_radius + self.addendum
@@ -52,7 +77,7 @@ class SpurGear:
mirrored = np.copy(involute)
mirrored[:, 1] *= -1
arc = []
arc: List[List[float]] = []
arc_steps = 10
start_angle = np.arctan2(mirrored[-1, 1], mirrored[-1, 0])
end_angle = -start_angle
@@ -65,18 +90,18 @@ class SpurGear:
profile = np.vstack([involute, arc, mirrored[::-1]])
if not np.allclose(profile[0], profile[-1]):
profile = np.vstack([profile, profile[0]])
logging.debug(f"Created tooth profile with {len(profile)} points")
logging.debug("Created tooth profile with %d points", len(profile))
return profile
def mesh(self):
def _create_mesh(self) -> tm.Trimesh:
tooth_profile = self.create_tooth()
polygon = Polygon(tooth_profile)
if not polygon.is_valid:
polygon = polygon.buffer(0)
logging.warning("Tooth polygon was invalid, repaired with buffer")
tooth_mesh = tm.creation.extrude_polygon(polygon, self.width, engine='triangle')
logging.debug(f"Extruded tooth mesh with {len(tooth_mesh.vertices)} vertices")
tooth_mesh = tm.creation.extrude_polygon(polygon, self.width, engine="triangle")
logging.debug("Extruded tooth mesh with %d vertices", len(tooth_mesh.vertices))
all_teeth = []
for i in range(self.teeth):
@@ -84,15 +109,19 @@ class SpurGear:
rot = tm.transformations.rotation_matrix(angle, [0, 0, 1])
rotated_tooth = tooth_mesh.copy().apply_transform(rot)
all_teeth.append(rotated_tooth)
logging.debug(f"Added tooth {i+1}/{self.teeth}")
logging.debug("Added tooth %d/%d", i + 1, self.teeth)
gear_body = tm.util.concatenate(all_teeth)
logging.debug(f"Combined {self.teeth} teeth into gear body with {len(gear_body.vertices)} vertices")
logging.debug(
"Combined %d teeth into gear body with %d vertices",
self.teeth,
len(gear_body.vertices),
)
bore = tm.creation.cylinder(radius=self.bore_diameter / 2, height=self.width + 0.1)
bore.apply_translation([0, 0, self.width / 2])
gear = safe_difference(gear_body, bore)
logging.debug(f"Subtracted bore, resulting mesh has {len(gear.vertices)} vertices")
logging.debug("Subtracted bore, resulting mesh has %d vertices", len(gear.vertices))
if self.hole_count > 0:
hole_cylinders = []
@@ -106,7 +135,11 @@ class SpurGear:
hole = hole.convex_hull
hole_cylinders.append(hole)
gear = safe_difference(gear, hole_cylinders)
logging.debug(f"Subtracted {self.hole_count} holes, resulting mesh has {len(gear.vertices)} vertices")
logging.debug(
"Subtracted %d holes, resulting mesh has %d vertices",
self.hole_count,
len(gear.vertices),
)
if not gear.is_watertight:
logging.warning("Final gear mesh is not watertight")

View File

@@ -1,17 +1,15 @@
from parametric_cad.core import tm
from .base import Primitive
class Sphere:
def __init__(self, radius, subdivisions=3):
class Sphere(Primitive):
def __init__(self, radius: float, subdivisions: int = 3) -> None:
super().__init__()
self.radius = radius
self.subdivisions = subdivisions
self._position = (0, 0, 0)
def at(self, x, y, z):
self._position = (x, y, z)
return self
def mesh(self):
sph = tm.creation.icosphere(subdivisions=self.subdivisions,
radius=self.radius)
sph.apply_translation(self._position)
return sph
def _create_mesh(self) -> tm.Trimesh:
return tm.creation.icosphere(
subdivisions=self.subdivisions,
radius=self.radius,
)

View File

@@ -1,11 +1,22 @@
from math import cos, sin, pi
from parametric_cad.core import tm, safe_difference
from math import cos, pi, sin
class ChainSprocket:
from parametric_cad.core import safe_difference, tm
from .base import Primitive
class ChainSprocket(Primitive):
"""Simple chain sprocket for roller chain."""
def __init__(self, pitch=12.7, roller_diameter=7.75, teeth=16,
thickness=5.0, bore_diameter=10.0, clearance=0.5):
def __init__(
self,
pitch: float = 12.7,
roller_diameter: float = 7.75,
teeth: int = 16,
thickness: float = 5.0,
bore_diameter: float = 10.0,
clearance: float = 0.5,
) -> None:
super().__init__()
self.pitch = float(pitch)
self.roller_diameter = float(roller_diameter)
self.teeth = int(teeth)
@@ -14,23 +25,24 @@ class ChainSprocket:
self.clearance = float(clearance)
@property
def pitch_radius(self):
def pitch_radius(self) -> float:
return self.pitch / (2 * sin(pi / self.teeth))
@property
def pitch_diameter(self):
def pitch_diameter(self) -> float:
return self.pitch_radius * 2
def mesh(self):
def _create_mesh(self) -> tm.Trimesh:
# Base disc sized so pockets can be subtracted
outer_radius = self.pitch_radius + self.roller_diameter / 2 + self.clearance
disc = tm.creation.cylinder(radius=outer_radius, height=self.thickness,
sections=self.teeth * 4)
disc = tm.creation.cylinder(
radius=outer_radius,
height=self.thickness,
sections=self.teeth * 4,
)
disc.apply_translation([0, 0, self.thickness / 2])
bore = tm.creation.cylinder(radius=self.bore_diameter / 2,
height=self.thickness + 0.1)
bore = tm.creation.cylinder(radius=self.bore_diameter / 2, height=self.thickness + 0.1)
bore.apply_translation([0, 0, self.thickness / 2])
sprocket = safe_difference(disc, bore)
@@ -40,9 +52,11 @@ class ChainSprocket:
angle = 2 * pi * i / self.teeth
x = cos(angle) * self.pitch_radius
y = sin(angle) * self.pitch_radius
pocket = tm.creation.cylinder(radius=pocket_radius,
height=self.thickness + 0.1,
sections=16)
pocket = tm.creation.cylinder(
radius=pocket_radius,
height=self.thickness + 0.1,
sections=16,
)
pocket.apply_translation([x, y, self.thickness / 2])
pockets.append(pocket)

View File

@@ -3,6 +3,7 @@ import numpy as np
from parametric_cad.core import tm, safe_difference, combine
from math import cos, sin, pi
from parametric_cad.primitives.base import Primitive
from parametric_cad.primitives.box import Box
from parametric_cad.primitives.gear import SpurGear
from parametric_cad.primitives.cylinder import Cylinder
@@ -60,3 +61,11 @@ def test_chain_sprocket_properties_and_mesh():
mesh = sprocket.mesh()
assert isinstance(mesh, tm.Trimesh)
assert mesh.is_watertight
def test_primitive_inheritance():
assert isinstance(Box(1, 1, 1), Primitive)
assert isinstance(Cylinder(1, 1), Primitive)
assert isinstance(Sphere(1), Primitive)
assert isinstance(SpurGear(module=1.0, teeth=8), Primitive)
assert isinstance(ChainSprocket(), Primitive)