Skip to content

File Rays.h

File List > Intern > rayx-core > src > Rays.h

Go to the documentation of this file

#pragma once

#include <algorithm>
#include <numeric>
#include <type_traits>
#include <vector>

#include "Debug/Instrumentor.h"
#include "RayAttrMask.h"

namespace rayx {

/// Ray data stored as one std::vector per ray attribute (structure of arrays).
///
/// The attribute vectors are generated from RAYX_X_MACRO_RAY_ATTR. Element `i` of every
/// populated attribute vector describes the same ray (e.g. position(i) combines
/// position_x[i], position_y[i] and position_z[i]).
///
/// Copy construction and copy assignment are protected, so copies cannot be made
/// implicitly from outside the class; copy() is the public way to duplicate.
/// Moves are public and are required to be noexcept (checked by a static_assert after
/// the class).
struct RAYX_API Rays {
  protected:
    Rays(const Rays&) = default;
    Rays& operator=(const Rays&) = default;

  public:
    Rays()                  = default;
    Rays(Rays&&)            = default;
    Rays& operator=(Rays&&) = default;

    // One public std::vector<type> member per entry of RAYX_X_MACRO_RAY_ATTR.
#define X(type, name, flag) std::vector<type> name;

    RAYX_X_MACRO_RAY_ATTR
#undef X

    /// Position of ray `i`, assembled from position_x/y/z. No bounds checking.
    [[nodiscard]] glm::dvec3 position(const int i) const { return glm::dvec3(position_x[i], position_y[i], position_z[i]); }
    /// Writes `position` into position_x/y/z at index `i`. No bounds checking: the vectors must already hold index `i`.
    void position(const int i, const glm::dvec3 position) {
        position_x[i] = position.x;
        position_y[i] = position.y;
        position_z[i] = position.z;
    }

    /// Direction of ray `i`, assembled from direction_x/y/z. No bounds checking.
    [[nodiscard]] glm::dvec3 direction(const int i) const { return glm::dvec3(direction_x[i], direction_y[i], direction_z[i]); }
    /// Writes `direction` into direction_x/y/z at index `i`. No bounds checking: the vectors must already hold index `i`.
    void direction(const int i, const glm::dvec3 direction) {
        direction_x[i] = direction.x;
        direction_y[i] = direction.y;
        direction_z[i] = direction.z;
    }

    /// Electric field of ray `i`, assembled from electric_field_x/y/z. No bounds checking.
    [[nodiscard]] ElectricField electric_field(const int i) const { return ElectricField(electric_field_x[i], electric_field_y[i], electric_field_z[i]); }
    /// Writes `electric_field` into electric_field_x/y/z at index `i`. No bounds checking: the vectors must already hold index `i`.
    void electric_field(const int i, const ElectricField electric_field) {
        electric_field_x[i] = electric_field.x;
        electric_field_y[i] = electric_field.y;
        electric_field_z[i] = electric_field.z;
    }

    /// Explicit duplicate of this object; the public counterpart of the protected copy
    /// constructor (implemented out of line).
    [[nodiscard]] Rays copy() const;

    /// Mask of the attributes this object holds. sort() and filter() only copy attributes
    /// whose flag is set in this mask.
    [[nodiscard]] RayAttrMask attrMask() const;

    /// NOTE(review): presumably true when every attribute in `attr` is present - confirm in the implementation.
    [[nodiscard]] bool contains(const RayAttrMask attr) const;

    /// NOTE(review): presumably true when there are no rays - confirm in the implementation.
    [[nodiscard]] bool empty() const;

    /// Number of rays. sort(), filter() and count() iterate over the indices [0, size()).
    [[nodiscard]] int size() const;

    /// NOTE(review): presumably the number of distinct ray paths - confirm in the implementation.
    [[nodiscard]] int numPaths() const;

    /// Appends `other` to this object (implemented out of line). Returns a reference
    /// for chaining, presumably to *this.
    Rays& append(const Rays& other);

    /// Concatenates the elements of `rays_list` into a new object (implemented out of line).
    [[nodiscard]] static Rays concat(const std::vector<Rays>& rays_list);

    /// Reordered copy, ordered by object id (implemented out of line).
    [[nodiscard]] Rays sortByObjectId() const;

    /// Reordered copy, ordered by path id and then path event id (implemented out of line).
    [[nodiscard]] Rays sortByPathIdAndPathEventId() const;

    /// Reordered copy. `comp(int a, int b)` compares ray *indices* and must be a strict
    /// weak ordering, as for std::sort.
    template <typename Compare>
    [[nodiscard]] Rays sort(Compare comp) const;

    /// Non-const filter operating on this object; returns a reference for chaining.
    /// NOTE(review): presumably drops attributes not contained in `mask` - confirm in the implementation.
    Rays& filterByAttrMask(const RayAttrMask mask);

    /// Copy containing only the rays that belong to `object_id` (implemented out of line).
    [[nodiscard]] Rays filterByObjectId(const int object_id) const;

    /// Copy containing only the last event of each path (implemented out of line).
    [[nodiscard]] Rays filterByLastEventInPath() const;

    /// Copy containing only the rays whose index satisfies `pred(int index)`.
    template <typename Pred>
    [[nodiscard]] Rays filter(Pred pred) const;

    /// Number of indices in [0, size()) for which `pred(int index)` is true.
    template <typename Pred>
    [[nodiscard]] int count(Pred pred) const;

    /// NOTE(review): presumably checks internal consistency (e.g. matching attribute sizes) - confirm in the implementation.
    [[nodiscard]] bool isValid() const;

    // TODO: implement helper methods to iterate over attributes, to get rid of most of the X-macros
};

/// Returns a new Rays whose rays are reordered according to `comp`.
///
/// `comp` is called as `comp(int a, int b)` with ray *indices*. It must return true if
/// ray `a` should come before ray `b` (a strict weak ordering). Only attributes whose
/// flag is set in attrMask() are copied into the result; all other attribute vectors of
/// the result stay empty.
///
/// The sort is stable: rays that compare equal keep their relative order from this
/// object. The result therefore does not depend on how a particular standard library
/// breaks ties in std::sort.
///
/// NOTE(review): assumes every present attribute vector holds size() elements (the
/// gather reads name[i] for every i < size()).
template <typename Compare>
Rays Rays::sort(Compare comp) const {
    RAYX_PROFILE_FUNCTION_STDOUT();

    const auto attr = attrMask();
    const auto n    = size();

    // Permutation of [0, n): indices[k] is the source index of the k-th ray in the result.
    auto indices = std::vector<int>(n);
    std::iota(indices.begin(), indices.end(), 0);
    std::stable_sort(indices.begin(), indices.end(), comp);

    Rays result;
    // Gather every present attribute through the permutation. The output is sized by the
    // permutation (as in filter()), so every write stays inside result.name.
#define X(type, name, flag)                                                                                           \
    if (!!(attr & RayAttrMask::flag)) {                                                                               \
        result.name.resize(indices.size());                                                                           \
        std::transform(indices.begin(), indices.end(), result.name.begin(), [this](const int i) { return name[i]; }); \
    }
    RAYX_X_MACRO_RAY_ATTR
#undef X

    return result;
}

/// Returns a new Rays that keeps only the rays whose index satisfies `pred`.
///
/// `pred` is invoked exactly once per ray, in ascending index order, as `pred(int index)`.
/// Kept rays retain their relative order. Only attributes whose flag is set in attrMask()
/// are populated in the result; the remaining attribute vectors stay empty.
template <typename Pred>
Rays Rays::filter(Pred pred) const {
    RAYX_PROFILE_FUNCTION_STDOUT();

    const auto attrs   = attrMask();
    const auto numRays = size();

    // First pass: record which ray indices survive the predicate.
    std::vector<int> selected;
    for (int k = 0; k < numRays; ++k) {
        if (!pred(k)) continue;
        selected.push_back(k);
    }

    // Second pass: gather each present attribute at the surviving indices.
    Rays filtered;
#define X(type, name, flag)                                            \
    if (!!(attrs & RayAttrMask::flag)) {                               \
        filtered.name.reserve(selected.size());                        \
        for (const int k : selected) filtered.name.push_back(name[k]); \
    }
    RAYX_X_MACRO_RAY_ATTR
#undef X

    return filtered;
}

/// Counts how many rays satisfy `pred`.
///
/// `pred` is called as `pred(int index)` once for every index in [0, size()), in
/// ascending order.
template <typename Pred>
int Rays::count(Pred pred) const {
    int matches = 0;
    for (int idx = 0, total = size(); idx < total; ++idx) {
        if (!pred(idx)) continue;
        ++matches;
    }
    return matches;
}

// Ensure efficient moves when Rays is used in std::vector<Rays>: both explicitly defaulted
// move operations must be noexcept.
static_assert(std::is_nothrow_move_constructible_v<Rays>, "Rays must be nothrow move constructible");
static_assert(std::is_nothrow_move_assignable_v<Rays>, "Rays must be nothrow move assignable");

/// Equality / inequality of two Rays (defined out of line).
bool RAYX_API operator==(const Rays& lhs, const Rays& rhs);
bool RAYX_API operator!=(const Rays& lhs, const Rays& rhs);

}  // namespace rayx