001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *   http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 *
017 */
018
019package org.apache.commons.compress.utils;
020
021import java.io.File;
022import java.io.IOException;
023import java.nio.ByteBuffer;
024import java.nio.channels.ClosedChannelException;
025import java.nio.channels.NonWritableChannelException;
026import java.nio.channels.SeekableByteChannel;
027import java.nio.file.Files;
028import java.nio.file.Path;
029import java.nio.file.StandardOpenOption;
030import java.util.ArrayList;
031import java.util.Arrays;
032import java.util.Collections;
033import java.util.List;
034import java.util.Objects;
035
036/**
037 * Read-Only Implementation of {@link SeekableByteChannel} that
038 * concatenates a collection of other {@link SeekableByteChannel}s.
039 *
040 * <p>This is a lose port of <a
041 * href="https://github.com/frugalmechanic/fm-common/blob/master/jvm/src/main/scala/fm/common/MultiReadOnlySeekableByteChannel.scala">MultiReadOnlySeekableByteChannel</a>
042 * by Tim Underwood.</p>
043 *
044 * @since 1.19
045 */
046public class MultiReadOnlySeekableByteChannel implements SeekableByteChannel {
047
048    private static final Path[] EMPTY_PATH_ARRAY = {};
049
050    /**
051     * Concatenates the given files.
052     *
053     * @param files the files to concatenate
054     * @throws NullPointerException if files is null
055     * @throws IOException if opening a channel for one of the files fails
056     * @return SeekableByteChannel that concatenates all provided files
057     */
058    public static SeekableByteChannel forFiles(final File... files) throws IOException {
059        final List<Path> paths = new ArrayList<>();
060        for (final File f : Objects.requireNonNull(files, "files must not be null")) {
061            paths.add(f.toPath());
062        }
063
064        return forPaths(paths.toArray(EMPTY_PATH_ARRAY));
065    }
066
067    /**
068     * Concatenates the given file paths.
069     * @param paths the file paths to concatenate, note that the LAST FILE of files should be the LAST SEGMENT(.zip)
070     * and these files should be added in correct order (e.g.: .z01, .z02... .z99, .zip)
071     * @return SeekableByteChannel that concatenates all provided files
072     * @throws NullPointerException if files is null
073     * @throws IOException if opening a channel for one of the files fails
074     * @throws IOException if the first channel doesn't seem to hold
075     * the beginning of a split archive
076     * @since 1.22
077     */
078    public static SeekableByteChannel forPaths(final Path... paths) throws IOException {
079        final List<SeekableByteChannel> channels = new ArrayList<>();
080        for (final Path path : Objects.requireNonNull(paths, "paths must not be null")) {
081            channels.add(Files.newByteChannel(path, StandardOpenOption.READ));
082        }
083        if (channels.size() == 1) {
084            return channels.get(0);
085        }
086        return new MultiReadOnlySeekableByteChannel(channels);
087    }
088
089    /**
090     * Concatenates the given channels.
091     *
092     * @param channels the channels to concatenate
093     * @throws NullPointerException if channels is null
094     * @return SeekableByteChannel that concatenates all provided channels
095     */
096    public static SeekableByteChannel forSeekableByteChannels(final SeekableByteChannel... channels) {
097        if (Objects.requireNonNull(channels, "channels must not be null").length == 1) {
098            return channels[0];
099        }
100        return new MultiReadOnlySeekableByteChannel(Arrays.asList(channels));
101    }
102
103    private final List<SeekableByteChannel> channels;
104
105    private long globalPosition;
106
107    private int currentChannelIdx;
108
109    /**
110     * Concatenates the given channels.
111     *
112     * @param channels the channels to concatenate
113     * @throws NullPointerException if channels is null
114     */
115    public MultiReadOnlySeekableByteChannel(final List<SeekableByteChannel> channels) {
116        this.channels = Collections.unmodifiableList(new ArrayList<>(
117            Objects.requireNonNull(channels, "channels must not be null")));
118    }
119
120    @Override
121    public void close() throws IOException {
122        IOException first = null;
123        for (final SeekableByteChannel ch : channels) {
124            try {
125                ch.close();
126            } catch (final IOException ex) {
127                if (first == null) {
128                    first = ex;
129                }
130            }
131        }
132        if (first != null) {
133            throw new IOException("failed to close wrapped channel", first);
134        }
135    }
136
137    @Override
138    public boolean isOpen() {
139        return channels.stream().allMatch(SeekableByteChannel::isOpen);
140    }
141
142    /**
143     * Gets this channel's position.
144     *
145     * <p>This method violates the contract of {@link SeekableByteChannel#position()} as it will not throw any exception
146     * when invoked on a closed channel. Instead it will return the position the channel had when close has been
147     * called.</p>
148     */
149    @Override
150    public long position() {
151        return globalPosition;
152    }
153
154    @Override
155    public synchronized SeekableByteChannel position(final long newPosition) throws IOException {
156        if (newPosition < 0) {
157            throw new IllegalArgumentException("Negative position: " + newPosition);
158        }
159        if (!isOpen()) {
160            throw new ClosedChannelException();
161        }
162
163        globalPosition = newPosition;
164
165        long pos = newPosition;
166
167        for (int i = 0; i < channels.size(); i++) {
168            final SeekableByteChannel currentChannel = channels.get(i);
169            final long size = currentChannel.size();
170
171            final long newChannelPos;
172            if (pos == -1L) {
173                // Position is already set for the correct channel,
174                // the rest of the channels get reset to 0
175                newChannelPos = 0;
176            } else if (pos <= size) {
177                // This channel is where we want to be
178                currentChannelIdx = i;
179                final long tmp = pos;
180                pos = -1L; // Mark pos as already being set
181                newChannelPos = tmp;
182            } else {
183                // newPosition is past this channel.  Set channel
184                // position to the end and substract channel size from
185                // pos
186                pos -= size;
187                newChannelPos = size;
188            }
189
190            currentChannel.position(newChannelPos);
191        }
192        return this;
193    }
194
195    /**
196     * Sets the position based on the given channel number and relative offset
197     *
198     * @param channelNumber  the channel number
199     * @param relativeOffset the relative offset in the corresponding channel
200     * @return global position of all channels as if they are a single channel
201     * @throws IOException if positioning fails
202     */
203    public synchronized SeekableByteChannel position(final long channelNumber, final long relativeOffset) throws IOException {
204        if (!isOpen()) {
205            throw new ClosedChannelException();
206        }
207        long globalPosition = relativeOffset;
208        for (int i = 0; i < channelNumber; i++) {
209            globalPosition += channels.get(i).size();
210        }
211
212        return position(globalPosition);
213    }
214
215    @Override
216    public synchronized int read(final ByteBuffer dst) throws IOException {
217        if (!isOpen()) {
218            throw new ClosedChannelException();
219        }
220        if (!dst.hasRemaining()) {
221            return 0;
222        }
223
224        int totalBytesRead = 0;
225        while (dst.hasRemaining() && currentChannelIdx < channels.size()) {
226            final SeekableByteChannel currentChannel = channels.get(currentChannelIdx);
227            final int newBytesRead = currentChannel.read(dst);
228            if (newBytesRead == -1) {
229                // EOF for this channel -- advance to next channel idx
230                currentChannelIdx += 1;
231                continue;
232            }
233            if (currentChannel.position() >= currentChannel.size()) {
234                // we are at the end of the current channel
235                currentChannelIdx++;
236            }
237            totalBytesRead += newBytesRead;
238        }
239        if (totalBytesRead > 0) {
240            globalPosition += totalBytesRead;
241            return totalBytesRead;
242        }
243        return -1;
244    }
245
246    @Override
247    public long size() throws IOException {
248        if (!isOpen()) {
249            throw new ClosedChannelException();
250        }
251        long acc = 0;
252        for (final SeekableByteChannel ch : channels) {
253            acc += ch.size();
254        }
255        return acc;
256    }
257
258    /**
259     * @throws NonWritableChannelException since this implementation is read-only.
260     */
261    @Override
262    public SeekableByteChannel truncate(final long size) {
263        throw new NonWritableChannelException();
264    }
265
266    /**
267     * @throws NonWritableChannelException since this implementation is read-only.
268     */
269    @Override
270    public int write(final ByteBuffer src) {
271        throw new NonWritableChannelException();
272    }
273
274}