Skip to content

Commit

Permalink
feat: initial binary op implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
lbressler13 committed Mar 1, 2024
1 parent 24095dd commit 49d072d
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ sealed class Expression : Number() {

// abstract fun isZero(): Boolean

// abstract operator fun plus(other: Expression): Expression
// abstract operator fun minus(other: Expression): Expression
// abstract operator fun times(other: Expression): Expression
// abstract operator fun div(other: Expression): Expression
abstract operator fun plus(other: Expression): Expression
abstract operator fun minus(other: Expression): Expression
abstract operator fun times(other: Expression): Expression
abstract operator fun div(other: Expression): Expression

abstract fun toTerm(): Term
abstract fun getSimplified(): Expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,20 @@ import xyz.lbres.exactnumbers.utils.castToFloat
import xyz.lbres.exactnumbers.utils.castToInt
import xyz.lbres.exactnumbers.utils.castToLong
import xyz.lbres.exactnumbers.utils.castToShort
import xyz.lbres.exactnumbers.utils.divideByZero

// internal implementation of expression
@Suppress("EqualsOrHashCode")
internal abstract class ExpressionImpl : Expression() {
override fun minus(other: Expression): Expression = plus(-other)

override fun div(other: Expression): Expression {
if (other == ZERO) {
throw divideByZero
}
return times(other.inverse())
}

override fun equals(other: Any?): Boolean = other is Expression && getValue() == other.getValue()

override fun toByte(): Byte = castToByte(getValue(), this, "Expression")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import java.math.BigDecimal
* Can also represent 1/sum
*/
@Suppress("EqualsOrHashCode")
internal class AdditiveExpression private constructor(private val expressions: ConstMultiSet<Expression>, private val isInverted: Boolean) : ExpressionImpl() {
internal class AdditiveExpression private constructor(val expressions: ConstMultiSet<Expression>, private val isInverted: Boolean) : ExpressionImpl() {
private var term: Term? = null

init {
Expand All @@ -27,6 +27,7 @@ internal class AdditiveExpression private constructor(private val expressions: C
}

constructor(expr1: Expression, expr2: Expression) : this(constMultiSetOf(expr1, expr2), false)
constructor(expressions: ConstMultiSet<Expression>) : this(expressions, false)

override fun unaryPlus(): Expression = this
override fun unaryMinus(): Expression {
Expand All @@ -53,6 +54,17 @@ internal class AdditiveExpression private constructor(private val expressions: C
// TODO
override fun getSimplified(): Expression = this

override fun plus(other: Expression): AdditiveExpression {
if (other is AdditiveExpression && !isInverted && !other.isInverted) {
val newExpressions = (expressions + other.expressions).toConstMultiSet()
return AdditiveExpression(newExpressions, false)
}

return AdditiveExpression((expressions + constMultiSetOf(other)).toConstMultiSet(), false)
}

override fun times(other: Expression): Expression = MultiplicativeExpression(this, other)

override fun hashCode(): Int = createHashCode(listOf(expressions, "AdditiveExpression"))

override fun toString(): String = "(${expressions.joinToString("+")})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import xyz.lbres.exactnumbers.utils.createHashCode
import xyz.lbres.exactnumbers.utils.getOrSet
import xyz.lbres.kotlinutils.bigdecimal.ext.isZero
import xyz.lbres.kotlinutils.collection.ext.toConstMultiSet
import xyz.lbres.kotlinutils.iterable.ext.forEachWith
import xyz.lbres.kotlinutils.set.multiset.anyConsistent
import xyz.lbres.kotlinutils.set.multiset.const.ConstMultiSet
import xyz.lbres.kotlinutils.set.multiset.const.ConstMutableMultiSet
import xyz.lbres.kotlinutils.set.multiset.const.constMultiSetOf
import xyz.lbres.kotlinutils.set.multiset.const.constMutableMultiSetOf
import xyz.lbres.kotlinutils.set.multiset.mapToSetConsistent
import java.math.BigDecimal

Expand Down Expand Up @@ -48,9 +51,64 @@ internal class MultiplicativeExpression private constructor(expressions: ConstMu
}
}

override fun getSimplified(): Expression = SimpleExpression(toTerm())
private fun getSplitExpressions(): Pair<ConstMultiSet<SimpleExpression>, ConstMultiSet<AdditiveExpression>> {
val simple: ConstMutableMultiSet<SimpleExpression> = constMutableMultiSetOf()
val additive: ConstMutableMultiSet<AdditiveExpression> = constMutableMultiSetOf()

expressions.forEach {
when (it) {
is SimpleExpression -> simple.add(it)
is AdditiveExpression -> additive.add(it)
is MultiplicativeExpression -> {
val split = it.getSplitExpressions()
simple.addAll(split.first)
additive.addAll(split.second)
}
}
}

return Pair(simple, additive)
}

// override fun getSimplified(): Expression = SimpleExpression(toTerm())
override fun getSimplified(): Expression {
val split = getSplitExpressions()
val simpleTerm = split.first.fold(Term.ONE) { acc, expr -> acc * expr.toTerm() }
val simple: Expression = SimpleExpression(simpleTerm.getSimplified())

if (split.second.isEmpty()) {
return simple
}

// TODO extract coefficient for each additive expr

val exprs = split.second.fold(constMultiSetOf(simple)) { acc, additiveExpr ->
val distributed: ConstMutableMultiSet<Expression> = constMutableMultiSetOf()
acc.forEachWith(additiveExpr.expressions) { expr1, expr2 ->
distributed.add(expr1 * expr2)
}
distributed
}
return AdditiveExpression(exprs.toConstMultiSet()).getSimplified()
}

override fun getValue(): BigDecimal = getSimplified().getValue()

override fun plus(other: Expression): Expression = AdditiveExpression(this, other)

override fun times(other: Expression): Expression {
return when (other) {
ZERO -> ZERO
ONE -> this
is MultiplicativeExpression -> {
MultiplicativeExpression((expressions + other.expressions).toConstMultiSet())
}
else -> {
MultiplicativeExpression((expressions + constMultiSetOf(other)).toConstMultiSet())
}
}
}

override fun hashCode(): Int = createHashCode(listOf(expressions, "MultiplicativeExpression"))

override fun toString(): String = "(${expressions.joinToString("x")})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,43 @@ import xyz.lbres.exactnumbers.expressions.Expression
import xyz.lbres.exactnumbers.expressions.ExpressionImpl
import xyz.lbres.exactnumbers.expressions.term.Term
import xyz.lbres.exactnumbers.utils.createHashCode
import xyz.lbres.exactnumbers.utils.getOrSet
import java.math.BigDecimal

/**
* Expression consisting of a single term
*/
@Suppress("EqualsOrHashCode")
internal class SimpleExpression(private val term: Term) : ExpressionImpl() {
private var simplified: Expression? = null

override fun unaryPlus(): Expression = this
override fun unaryMinus(): Expression = SimpleExpression(-term)
override fun inverse(): Expression = SimpleExpression(term.inverse())

override fun toTerm(): Term = term
override fun getSimplified(): Expression = getOrSet({ simplified }, { simplified = it }) { SimpleExpression(term.getSimplified()) }
override fun getSimplified(): SimpleExpression = SimpleExpression(term.getSimplified())
override fun getValue(): BigDecimal = term.getValue()

override fun plus(other: Expression): Expression {
if (other is SimpleExpression) {
val simplified = term.getSimplified()
val otherSimplified = other.term.getSimplified()
if (simplified.factors == otherSimplified.factors) {
val newCoefficient = simplified.coefficient + otherSimplified.coefficient
return SimpleExpression(Term.fromValues(newCoefficient, simplified.factors))
}
}

return AdditiveExpression(this, other)
}

override fun times(other: Expression): Expression {
return if (other is SimpleExpression) {
SimpleExpression(this.term * other.term)
} else {
MultiplicativeExpression(this, other)
}
}

override fun hashCode(): Int = createHashCode(listOf(term, "Expression"))

override fun toString(): String = "($term)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ class AdditiveExpressionTest {
// @Test fun testInverse() = runInverseTests() // TODO
// @Test fun testGetValue() = runGetValueTests() // TODO

// @Test fun testPlus() = runPlusTests() // TODO
// @Test fun testMinus() = runMinusTests() // TODO
// @Test fun testTimes() = runTimesTests() // TODO
// @Test fun testDiv() = runDivTests() // TODO

// @Test fun testToTerm() = runToTermTests() // TODO
// @Test fun testToByte() = runToByteTests() // TODO
// @Test fun testToChar() = runToCharTests() // TODO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class MultiplicativeExpressionTest {
@Test fun testGetValue() = runGetValueTests()
@Test fun testGetSimplified() = runGetSimplifiedTests()

// @Test fun testPlus() = runPlusTests() // TODO
// @Test fun testMinus() = runMinusTests() // TODO
// @Test fun testTimes() = runTimesTests() // TODO
// @Test fun testDiv() = runDivTests() // TODO

@Test fun testToTerm() = runToTermTests()
@Test fun testToByte() = runToByteTests()
@Test fun testToChar() = runToCharTests()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class SimpleExpressionTest {
@Test fun testGetValue() = runGetValueTests()
@Test fun testGetSimplified() = runGetSimplifiedTests()

// @Test fun testPlus() = runPlusTests() // TODO
// @Test fun testMinus() = runMinusTests() // TODO
// @Test fun testTimes() = runTimesTests() // TODO
// @Test fun testDiv() = runDivTests() // TODO

@Test fun testToTerm() = runToTermTests()
@Test fun testToByte() = runToByteTests()
@Test fun testToChar() = runToCharTests()
Expand Down

0 comments on commit 49d072d

Please sign in to comment.