diff --git a/matrix.hpp b/matrix.hpp index 3f844af..37250c2 100644 --- a/matrix.hpp +++ b/matrix.hpp @@ -11,6 +11,7 @@ private: uint64_t entry_dimension; vector *entries; bool err; + bool augmented = false; public: matrix(const uint64_t num_entries, const uint64_t entry_dimension) { @@ -38,6 +39,24 @@ public: this->entries = NULL; this->err = true; } + const matrix augment(vector v) const { + matrix m = *this; + if (v.get_dimention() != this->get_num_entries()) { + m.err = true; + return m; + } + for (uint64_t i = 0; i < m.get_num_entries(); i++) { + vector old = m[i]; + vector n(old.get_dimention() + 1); + for (uint64_t j = 0; j < old.get_dimention(); j++) + n[j] = old[j]; + n[old.get_dimention()] = v[i]; + m[i] = n; + } + m.entry_dimension = m.entry_dimension + 1; + m.augmented = true; + return m; + } const uint64_t get_num_entries() const { return this->num_entries; } const uint64_t get_entry_dimension() const { return this->entry_dimension; } const cnumber determinant() const { @@ -81,30 +100,30 @@ public: // switch row i with row j const matrix exchange_row(uint64_t i, uint64_t j) { - matrix m = *this; - vector v = m.get_entry(i); - m[i] = m[j]; - m[j] = v; - return m; + matrix m = *this; + vector v = m.get_entry(i); + m[i] = m[j]; + m[j] = v; + return m; } // subtract row i from row j const matrix subtract_row(uint64_t i, uint64_t j, cnumber multiplier = 1) { - matrix m = *this; - m[j] = m[j] - (m.multiply_row(i, multiplier)[i]); - return m; + matrix m = *this; + m[j] = m[j] - (m.multiply_row(i, multiplier)[i]); + return m; } // add row i to row j const matrix add_row(uint64_t i, uint64_t j, cnumber multiplier = 1) { - matrix m = *this; - m[j] = m[j] + (m.multiply_row(i, multiplier)[i]); - return m; + matrix m = *this; + m[j] = m[j] + (m.multiply_row(i, multiplier)[i]); + return m; } // Multiply row by z const matrix multiply_row(uint64_t i, cnumber z) { - matrix m = *this; - m[i] = m[i] * z; - return m; + matrix m = *this; + m[i] = m[i] * z; + return m; } const vector get_entry(uint64_t index) const { return this->entries[index]; } @@ -217,13 +236,15 @@ public: oss << std::setw(padding) << "|"; symbols[1] = oss.str(); symbols[2] = "|"; - for (int i = 0; i < 3; i++) { - int len = symbols[i].length() - 1; - char cur = symbols[i][0]; - if (cur != last) { - os << symbols[i]; + bool print = true; + for (int k = 0; k < 3; k++) { + int len = symbols[k].length() - 1; + char cur = symbols[k][0]; + if (cur != last || (m.augmented && j == m[i].get_dimention() - 1 && print)) { + print = false; + os << symbols[k]; } - last = symbols[i][len]; + last = symbols[k][len]; } } if (i != m.num_entries - 1)