//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>

/// \brief Common base for the MNIST file readers.
///
/// Stores a file name, the magic number expected for that kind of file, a C
/// stdio handle and the item count. Construction and destruction are
/// protected, so only the concrete loaders derived from it can be created.
/// The out-of-line members are defined in the implementation file.
class MNistLoader
{
protected:
    /// \param filename Path of the file to read.
    /// \param magic    Magic number associated with this kind of file.
    MNistLoader(const std::string& filename, uint32_t magic);
    virtual ~MNistLoader();

    /// \brief Header-reading hook; MNistImageLoader overrides it.
    virtual void read_header();

public:
    // The loader keeps a raw FILE* handle. A copy would alias that handle, so
    // closing one object would leave the other with a dangling pointer.
    MNistLoader(const MNistLoader&) = delete;
    MNistLoader& operator=(const MNistLoader&) = delete;

    void open();
    void close();
    void reset();

    /// \brief Reads \p n values of type T into \p loc (defined out of line).
    template <typename T>
    void read(T* loc, size_t n = 1);

    /// \brief Reads \p n values into \p loc as floats; the scaling applied is
    ///        determined by the implementation file.
    void read_scaled(float* loc, size_t n);

    /// \brief Raw fread() of \p n objects of type T from the underlying file.
    /// \return Number of complete objects read; 0 when no file is open.
    template <typename T>
    size_t file_read(T* loc, size_t n)
    {
        // Passing a null stream to fread() is undefined behaviour, so report
        // "nothing read" instead.
        if (m_file == nullptr)
        {
            return 0;
        }
        return fread(loc, sizeof(T), n, m_file);
    }

    /// \return The stored item count (0 until something assigns it).
    uint32_t get_items() const { return m_items; }

protected:
    std::string m_filename;
    FILE* m_file{nullptr};
    uint32_t m_magic;
    uint32_t m_items{0};
    // fpos_t is an opaque type (a struct on some platforms), so value-initialize
    // it rather than assuming it can be initialized from an integer.
    fpos_t m_data_pos{};
};

/// \brief Loader for MNIST image files; additionally exposes the image
///        dimensions.
class MNistImageLoader : public MNistLoader
{
    // Magic number for image files (private to this class).
    static const uint32_t magic_value = 0x00000803;

    void read_header() override;

public:
    MNistImageLoader(const std::string& file);

    // Defined in the implementation file; presumably the standard file names of
    // the test and training image sets -- TODO confirm.
    static const char* const TEST;
    static const char* const TRAIN;

    uint32_t get_rows() const { return m_rows; }
    uint32_t get_columns() const { return m_columns; }

protected:
    uint32_t m_rows{0};
    uint32_t m_columns{0};
};

/// \brief Loader for MNIST label files.
class MNistLabelLoader : public MNistLoader
{
    // Magic number for label files (private to this class).
    static const uint32_t magic_value = 0x00000801;

public:
    MNistLabelLoader(const std::string& file);

    // NOTE(review): unlike MNistImageLoader these pointers are not themselves
    // const, so they can be reassigned. Left as-is because the declaration has
    // to match the out-of-line definition.
    static const char* TEST;
    static const char* TRAIN;
};

/// \brief Pairs an image loader with a label loader and exposes batches of
///        float data through get_image_floats() / get_label_floats().
class MNistDataLoader
{
public:
    /// \param batch_size Number of samples per batch.
    /// \param image      Path handed to the image loader.
    /// \param label      Path handed to the label loader.
    MNistDataLoader(size_t batch_size, const std::string& image, const std::string& label);
    ~MNistDataLoader();

    void open();
    void close();

    uint32_t get_rows() const { return m_image_loader.get_rows(); }
    uint32_t get_columns() const { return m_image_loader.get_columns(); }
    size_t get_batch_size() const { return m_batch_size; }
    size_t get_items() const { return m_items; }
    size_t get_epoch() const { return m_epoch; }
    size_t get_pos() const { return m_pos; }

    void load();
    void rewind();
    void reset();

    /// \return Pointer to the image buffer (null until it has been allocated).
    const float* get_image_floats() const { return m_image_floats.get(); }
    /// \return Pointer to the label buffer (null until it has been allocated).
    const float* get_label_floats() const { return m_label_floats.get(); }

    /// \return Number of floats in one image batch: sample size * batch size.
    size_t get_image_batch_size() const { return m_image_sample_size * m_batch_size; }
    /// \return Number of floats in one label batch: one per sample.
    size_t get_label_batch_size() const { return m_batch_size; }

protected:
    size_t m_batch_size;
    MNistImageLoader m_image_loader;
    MNistLabelLoader m_label_loader;
    // NOTE(review): stored signed but returned as size_t by get_items(); a
    // negative value would convert to a very large count.
    int32_t m_items{0};
    size_t m_pos{0};
    size_t m_epoch{0};
    std::unique_ptr<float[]> m_image_floats;
    std::unique_ptr<float[]> m_label_floats;
    size_t m_image_sample_size{0};
};